reformat code

This commit is contained in:
hdt3213
2021-04-03 20:14:12 +08:00
parent bf913a5aca
commit bcf0cd5e92
54 changed files with 4887 additions and 4896 deletions

View File

@@ -2,11 +2,13 @@
[中文版](https://github.com/HDT3213/godis/blob/master/README_CN.md) [中文版](https://github.com/HDT3213/godis/blob/master/README_CN.md)
`Godis` is a simple implementation of Redis Server, which intents to provide an example of writing a high concurrent middleware using golang. `Godis` is a simple implementation of Redis Server, which intents to provide an example of writing a high concurrent
middleware using golang.
Please be advised, NEVER think about using this in production environment. Please be advised, NEVER think about using this in production environment.
This repository implemented most features of redis, including 5 data structures, ttl, publish/subscribe, AOF persistence and server side cluster mode. This repository implemented most features of redis, including 5 data structures, ttl, publish/subscribe, AOF persistence
and server side cluster mode.
If you could read Chinese, you can find more details in [My Blog](https://www.cnblogs.com/Finley/category/1598973.html). If you could read Chinese, you can find more details in [My Blog](https://www.cnblogs.com/Finley/category/1598973.html).
@@ -22,7 +24,7 @@ You could use redis-cli or other redis client to connect godis server, which lis
The program will try to read config file path from environment variable `CONFIG`. The program will try to read config file path from environment variable `CONFIG`.
If environment variable is not set, then the program try to read `redis.conf` in the working directory. If environment variable is not set, then the program try to read `redis.conf` in the working directory.
If there is no such file, then the program will run with default config. If there is no such file, then the program will run with default config.
@@ -35,8 +37,7 @@ peers localhost:7379,localhost:7389 // other node in cluster
self localhost:6399 // self address self localhost:6399 // self address
``` ```
We provide node1.conf and node2.conf for demonstration. We provide node1.conf and node2.conf for demonstration. use following command line to start a two-node-cluster:
use following command line to start a two-node-cluster:
```bash ```bash
CONFIG=node1.conf ./godis-darwin & CONFIG=node1.conf ./godis-darwin &
@@ -151,7 +152,7 @@ Supported Commands:
If you want to read my code in this repository, here is a simple guidance. If you want to read my code in this repository, here is a simple guidance.
- cmd: only the entry point - cmd: only the entry point
- config: config parser - config: config parser
- interface: some interface definitions - interface: some interface definitions
- lib: some utils, such as logger, sync utils and wildcard - lib: some utils, such as logger, sync utils and wildcard
@@ -167,7 +168,7 @@ I suggest focusing on the following directories:
- sortedset: a sorted set implements based on skiplist - sortedset: a sorted set implements based on skiplist
- db: the implements of the redis db - db: the implements of the redis db
- db.go: the basement of database - db.go: the basement of database
- router.go: it find handler for commands - router.go: it find handler for commands
- keys.go: handlers for keys commands - keys.go: handlers for keys commands
- string.go: handlers for string commands - string.go: handlers for string commands
- list.go: handlers for list commands - list.go: handlers for list commands
@@ -176,7 +177,7 @@ I suggest focusing on the following directories:
- sortedset.go: handlers for sorted set commands - sortedset.go: handlers for sorted set commands
- pubsub.go: implements of publish / subscribe - pubsub.go: implements of publish / subscribe
- aof.go: implements of AOF persistence and rewrite - aof.go: implements of AOF persistence and rewrite
# License # License
This project is licensed under the [GPL license](https://github.com/HDT3213/godis/blob/master/LICENSE). This project is licensed under the [GPL license](https://github.com/HDT3213/godis/blob/master/LICENSE).

View File

@@ -2,8 +2,8 @@ Godis 是一个用 Go 语言实现的 Redis 服务器。本项目旨在为尝试
**请注意:不要在生产环境使用使用此项目** **请注意:不要在生产环境使用使用此项目**
Godis 实现了 Redis 的大多数功能包括5种数据结构、TTL、发布订阅以及 AOF 持久化。可以在[我的博客](https://www.cnblogs.com/Finley/category/1598973.html)了解更多关于 Godis 的信息。 Godis 实现了 Redis 的大多数功能包括5种数据结构、TTL、发布订阅以及 AOF 持久化。可以在[我的博客](https://www.cnblogs.com/Finley/category/1598973.html)了解更多关于
Godis 的信息。
# 运行 Godis # 运行 Godis
@@ -149,19 +149,19 @@ redis-cli -p 6399
- tcp: tcp 服务器实现 - tcp: tcp 服务器实现
- redis: redis 协议解析器 - redis: redis 协议解析器
- datastruct: redis 的各类数据结构实现 - datastruct: redis 的各类数据结构实现
- dict: hash 表 - dict: hash 表
- list: 链表 - list: 链表
- lock: 用于锁定 key 的锁组件 - lock: 用于锁定 key 的锁组件
- set 基于hash表的集合 - set 基于hash表的集合
- sortedset: 基于跳表实现的有序集合 - sortedset: 基于跳表实现的有序集合
- db: redis 存储引擎实现 - db: redis 存储引擎实现
- db.go: 引擎的基础功能 - db.go: 引擎的基础功能
- router.go: 将命令路由给响应的处理函数 - router.go: 将命令路由给响应的处理函数
- keys.go: del、ttl、expire 等通用命令实现 - keys.go: del、ttl、expire 等通用命令实现
- string.go: get、set 等字符串命令实现 - string.go: get、set 等字符串命令实现
- list.go: lpush、lindex 等列表命令实现 - list.go: lpush、lindex 等列表命令实现
- hash.go: hget、hset 等哈希表命令实现 - hash.go: hget、hset 等哈希表命令实现
- set.go: sadd 等集合命令实现 - set.go: sadd 等集合命令实现
- sortedset.go: zadd 等有序集合命令实现 - sortedset.go: zadd 等有序集合命令实现
- pubsub.go: 发布订阅命令实现 - pubsub.go: 发布订阅命令实现
- aof.go: aof持久化实现 - aof.go: aof持久化实现

View File

@@ -1,45 +1,45 @@
package cluster package cluster
import ( import (
"context" "context"
"errors" "errors"
"github.com/HDT3213/godis/src/redis/client" "github.com/HDT3213/godis/src/redis/client"
"github.com/jolestar/go-commons-pool/v2" "github.com/jolestar/go-commons-pool/v2"
) )
type ConnectionFactory struct { type ConnectionFactory struct {
Peer string Peer string
} }
func (f *ConnectionFactory) MakeObject(ctx context.Context) (*pool.PooledObject, error) { func (f *ConnectionFactory) MakeObject(ctx context.Context) (*pool.PooledObject, error) {
c, err := client.MakeClient(f.Peer) c, err := client.MakeClient(f.Peer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.Start() c.Start()
return pool.NewPooledObject(c), nil return pool.NewPooledObject(c), nil
} }
func (f *ConnectionFactory) DestroyObject(ctx context.Context, object *pool.PooledObject) error { func (f *ConnectionFactory) DestroyObject(ctx context.Context, object *pool.PooledObject) error {
c, ok := object.Object.(*client.Client) c, ok := object.Object.(*client.Client)
if !ok { if !ok {
return errors.New("type mismatch") return errors.New("type mismatch")
} }
c.Close() c.Close()
return nil return nil
} }
func (f *ConnectionFactory) ValidateObject(ctx context.Context, object *pool.PooledObject) bool { func (f *ConnectionFactory) ValidateObject(ctx context.Context, object *pool.PooledObject) bool {
// do validate // do validate
return true return true
} }
func (f *ConnectionFactory) ActivateObject(ctx context.Context, object *pool.PooledObject) error { func (f *ConnectionFactory) ActivateObject(ctx context.Context, object *pool.PooledObject) error {
// do activate // do activate
return nil return nil
} }
func (f *ConnectionFactory) PassivateObject(ctx context.Context, object *pool.PooledObject) error { func (f *ConnectionFactory) PassivateObject(ctx context.Context, object *pool.PooledObject) error {
// do passivate // do passivate
return nil return nil
} }

View File

@@ -1,99 +1,99 @@
package cluster package cluster
import ( import (
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"strconv" "strconv"
) )
func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) < 2 { if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'del' command") return reply.MakeErrReply("ERR wrong number of arguments for 'del' command")
} }
keys := make([]string, len(args)-1) keys := make([]string, len(args)-1)
for i := 1; i < len(args); i++ { for i := 1; i < len(args); i++ {
keys[i-1] = string(args[i]) keys[i-1] = string(args[i])
} }
groupMap := cluster.groupBy(keys) groupMap := cluster.groupBy(keys)
if len(groupMap) == 1 { // do fast if len(groupMap) == 1 { // do fast
for peer, group := range groupMap { // only one group for peer, group := range groupMap { // only one group
return cluster.Relay(peer, c, makeArgs("DEL", group...)) return cluster.Relay(peer, c, makeArgs("DEL", group...))
} }
} }
// prepare // prepare
var errReply redis.Reply var errReply redis.Reply
txId := cluster.idGenerator.NextId() txId := cluster.idGenerator.NextId()
txIdStr := strconv.FormatInt(txId, 10) txIdStr := strconv.FormatInt(txId, 10)
rollback := false rollback := false
for peer, group := range groupMap { for peer, group := range groupMap {
args := []string{txIdStr} args := []string{txIdStr}
args = append(args, group...) args = append(args, group...)
var resp redis.Reply var resp redis.Reply
if peer == cluster.self { if peer == cluster.self {
resp = PrepareDel(cluster, c, makeArgs("PrepareDel", args...)) resp = PrepareDel(cluster, c, makeArgs("PrepareDel", args...))
} else { } else {
resp = cluster.Relay(peer, c, makeArgs("PrepareDel", args...)) resp = cluster.Relay(peer, c, makeArgs("PrepareDel", args...))
} }
if reply.IsErrorReply(resp) { if reply.IsErrorReply(resp) {
errReply = resp errReply = resp
rollback = true rollback = true
break break
} }
} }
var respList []redis.Reply var respList []redis.Reply
if rollback { if rollback {
// rollback // rollback
RequestRollback(cluster, c, txId, groupMap) RequestRollback(cluster, c, txId, groupMap)
} else { } else {
// commit // commit
respList, errReply = RequestCommit(cluster, c, txId, groupMap) respList, errReply = RequestCommit(cluster, c, txId, groupMap)
if errReply != nil { if errReply != nil {
rollback = true rollback = true
} }
} }
if !rollback { if !rollback {
var deleted int64 = 0 var deleted int64 = 0
for _, resp := range respList { for _, resp := range respList {
intResp := resp.(*reply.IntReply) intResp := resp.(*reply.IntReply)
deleted += intResp.Code deleted += intResp.Code
} }
return reply.MakeIntReply(int64(deleted)) return reply.MakeIntReply(int64(deleted))
} }
return errReply return errReply
} }
// args: PrepareDel id keys... // args: PrepareDel id keys...
func PrepareDel(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { func PrepareDel(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) < 3 { if len(args) < 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'preparedel' command") return reply.MakeErrReply("ERR wrong number of arguments for 'preparedel' command")
} }
txId := string(args[1]) txId := string(args[1])
keys := make([]string, 0, len(args)-2) keys := make([]string, 0, len(args)-2)
for i := 2; i < len(args); i++ { for i := 2; i < len(args); i++ {
arg := args[i] arg := args[i]
keys = append(keys, string(arg)) keys = append(keys, string(arg))
} }
txArgs := makeArgs("DEL", keys...) // actual args for cluster.db txArgs := makeArgs("DEL", keys...) // actual args for cluster.db
tx := NewTransaction(cluster, c, txId, txArgs, keys) tx := NewTransaction(cluster, c, txId, txArgs, keys)
cluster.transactions.Put(txId, tx) cluster.transactions.Put(txId, tx)
err := tx.prepare() err := tx.prepare()
if err != nil { if err != nil {
return reply.MakeErrReply(err.Error()) return reply.MakeErrReply(err.Error())
} }
return &reply.OkReply{} return &reply.OkReply{}
} }
// invoker should provide lock // invoker should provide lock
func CommitDel(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply { func CommitDel(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply {
keys := make([]string, len(tx.args)) keys := make([]string, len(tx.args))
for i, v := range tx.args { for i, v := range tx.args {
keys[i] = string(v) keys[i] = string(v)
} }
keys = keys[1:] keys = keys[1:]
deleted := cluster.db.Removes(keys...) deleted := cluster.db.Removes(keys...)
if deleted > 0 { if deleted > 0 {
cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args)) cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args))
} }
return reply.MakeIntReply(int64(deleted)) return reply.MakeIntReply(int64(deleted))
} }

View File

@@ -1,85 +1,85 @@
package idgenerator package idgenerator
import ( import (
"hash/fnv" "hash/fnv"
"log" "log"
"sync" "sync"
"time" "time"
) )
const ( const (
workerIdBits int64 = 5 workerIdBits int64 = 5
datacenterIdBits int64 = 5 datacenterIdBits int64 = 5
sequenceBits int64 = 12 sequenceBits int64 = 12
maxWorkerId int64 = -1 ^ (-1 << uint64(workerIdBits)) maxWorkerId int64 = -1 ^ (-1 << uint64(workerIdBits))
maxDatacenterId int64 = -1 ^ (-1 << uint64(datacenterIdBits)) maxDatacenterId int64 = -1 ^ (-1 << uint64(datacenterIdBits))
maxSequence int64 = -1 ^ (-1 << uint64(sequenceBits)) maxSequence int64 = -1 ^ (-1 << uint64(sequenceBits))
timeLeft uint8 = 22 timeLeft uint8 = 22
dataLeft uint8 = 17 dataLeft uint8 = 17
workLeft uint8 = 12 workLeft uint8 = 12
twepoch int64 = 1525705533000 twepoch int64 = 1525705533000
) )
type IdGenerator struct { type IdGenerator struct {
mu *sync.Mutex mu *sync.Mutex
lastStamp int64 lastStamp int64
workerId int64 workerId int64
dataCenterId int64 dataCenterId int64
sequence int64 sequence int64
} }
func MakeGenerator(cluster string, node string) *IdGenerator { func MakeGenerator(cluster string, node string) *IdGenerator {
fnv64 := fnv.New64() fnv64 := fnv.New64()
_, _ = fnv64.Write([]byte(cluster)) _, _ = fnv64.Write([]byte(cluster))
dataCenterId := int64(fnv64.Sum64()) dataCenterId := int64(fnv64.Sum64())
fnv64.Reset() fnv64.Reset()
_, _ = fnv64.Write([]byte(node)) _, _ = fnv64.Write([]byte(node))
workerId := int64(fnv64.Sum64()) workerId := int64(fnv64.Sum64())
return &IdGenerator{ return &IdGenerator{
mu: &sync.Mutex{}, mu: &sync.Mutex{},
lastStamp: -1, lastStamp: -1,
dataCenterId: dataCenterId, dataCenterId: dataCenterId,
workerId: workerId, workerId: workerId,
sequence: 1, sequence: 1,
} }
} }
func (w *IdGenerator) getCurrentTime() int64 { func (w *IdGenerator) getCurrentTime() int64 {
return time.Now().UnixNano() / 1e6 return time.Now().UnixNano() / 1e6
} }
func (w *IdGenerator) NextId() int64 { func (w *IdGenerator) NextId() int64 {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
timestamp := w.getCurrentTime() timestamp := w.getCurrentTime()
if timestamp < w.lastStamp { if timestamp < w.lastStamp {
log.Fatal("can not generate id") log.Fatal("can not generate id")
} }
if w.lastStamp == timestamp { if w.lastStamp == timestamp {
w.sequence = (w.sequence + 1) & maxSequence w.sequence = (w.sequence + 1) & maxSequence
if w.sequence == 0 { if w.sequence == 0 {
for timestamp <= w.lastStamp { for timestamp <= w.lastStamp {
timestamp = w.getCurrentTime() timestamp = w.getCurrentTime()
} }
} }
} else { } else {
w.sequence = 0 w.sequence = 0
} }
w.lastStamp = timestamp w.lastStamp = timestamp
return ((timestamp - twepoch) << timeLeft) | (w.dataCenterId << dataLeft) | (w.workerId << workLeft) | w.sequence return ((timestamp - twepoch) << timeLeft) | (w.dataCenterId << dataLeft) | (w.workerId << workLeft) | w.sequence
} }
func (w *IdGenerator) tilNextMillis() int64 { func (w *IdGenerator) tilNextMillis() int64 {
timestamp := w.getCurrentTime() timestamp := w.getCurrentTime()
if timestamp <= w.lastStamp { if timestamp <= w.lastStamp {
timestamp = w.getCurrentTime() timestamp = w.getCurrentTime()
} }
return timestamp return timestamp
} }

View File

@@ -1,159 +1,159 @@
package cluster package cluster
import ( import (
"fmt" "fmt"
"github.com/HDT3213/godis/src/db" "github.com/HDT3213/godis/src/db"
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"strconv" "strconv"
) )
func MGet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { func MGet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) < 2 { if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command") return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command")
} }
keys := make([]string, len(args)-1) keys := make([]string, len(args)-1)
for i := 1; i < len(args); i++ { for i := 1; i < len(args); i++ {
keys[i-1] = string(args[i]) keys[i-1] = string(args[i])
} }
resultMap := make(map[string][]byte) resultMap := make(map[string][]byte)
groupMap := cluster.groupBy(keys) groupMap := cluster.groupBy(keys)
for peer, group := range groupMap { for peer, group := range groupMap {
resp := cluster.Relay(peer, c, makeArgs("MGET", group...)) resp := cluster.Relay(peer, c, makeArgs("MGET", group...))
if reply.IsErrorReply(resp) { if reply.IsErrorReply(resp) {
errReply := resp.(reply.ErrorReply) errReply := resp.(reply.ErrorReply)
return reply.MakeErrReply(fmt.Sprintf("ERR during get %s occurs: %v", group[0], errReply.Error())) return reply.MakeErrReply(fmt.Sprintf("ERR during get %s occurs: %v", group[0], errReply.Error()))
} }
arrReply, _ := resp.(*reply.MultiBulkReply) arrReply, _ := resp.(*reply.MultiBulkReply)
for i, v := range arrReply.Args { for i, v := range arrReply.Args {
key := group[i] key := group[i]
resultMap[key] = v resultMap[key] = v
} }
} }
result := make([][]byte, len(keys)) result := make([][]byte, len(keys))
for i, k := range keys { for i, k := range keys {
result[i] = resultMap[k] result[i] = resultMap[k]
} }
return reply.MakeMultiBulkReply(result) return reply.MakeMultiBulkReply(result)
} }
// args: PrepareMSet id keys... // args: PrepareMSet id keys...
func PrepareMSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { func PrepareMSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) < 3 { if len(args) < 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'preparemset' command") return reply.MakeErrReply("ERR wrong number of arguments for 'preparemset' command")
} }
txId := string(args[1]) txId := string(args[1])
size := (len(args) - 2) / 2 size := (len(args) - 2) / 2
keys := make([]string, size) keys := make([]string, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
keys[i] = string(args[2*i+2]) keys[i] = string(args[2*i+2])
} }
txArgs := [][]byte{ txArgs := [][]byte{
[]byte("MSet"), []byte("MSet"),
} // actual args for cluster.db } // actual args for cluster.db
txArgs = append(txArgs, args[2:]...) txArgs = append(txArgs, args[2:]...)
tx := NewTransaction(cluster, c, txId, txArgs, keys) tx := NewTransaction(cluster, c, txId, txArgs, keys)
cluster.transactions.Put(txId, tx) cluster.transactions.Put(txId, tx)
err := tx.prepare() err := tx.prepare()
if err != nil { if err != nil {
return reply.MakeErrReply(err.Error()) return reply.MakeErrReply(err.Error())
} }
return &reply.OkReply{} return &reply.OkReply{}
} }
// invoker should provide lock // invoker should provide lock
func CommitMSet(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply { func CommitMSet(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply {
size := len(tx.args) / 2 size := len(tx.args) / 2
keys := make([]string, size) keys := make([]string, size)
values := make([][]byte, size) values := make([][]byte, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
keys[i] = string(tx.args[2*i+1]) keys[i] = string(tx.args[2*i+1])
values[i] = tx.args[2*i+2] values[i] = tx.args[2*i+2]
} }
for i, key := range keys { for i, key := range keys {
value := values[i] value := values[i]
cluster.db.Put(key, &db.DataEntity{Data: value}) cluster.db.Put(key, &db.DataEntity{Data: value})
} }
cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args)) cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args))
return &reply.OkReply{} return &reply.OkReply{}
} }
func MSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { func MSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
argCount := len(args) - 1 argCount := len(args) - 1
if argCount%2 != 0 || argCount < 1 { if argCount%2 != 0 || argCount < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command") return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
} }
size := argCount / 2 size := argCount / 2
keys := make([]string, size) keys := make([]string, size)
valueMap := make(map[string]string) valueMap := make(map[string]string)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
keys[i] = string(args[2*i+1]) keys[i] = string(args[2*i+1])
valueMap[keys[i]] = string(args[2*i+2]) valueMap[keys[i]] = string(args[2*i+2])
} }
groupMap := cluster.groupBy(keys) groupMap := cluster.groupBy(keys)
if len(groupMap) == 1 { // do fast if len(groupMap) == 1 { // do fast
for peer := range groupMap { for peer := range groupMap {
return cluster.Relay(peer, c, args) return cluster.Relay(peer, c, args)
} }
} }
//prepare //prepare
var errReply redis.Reply var errReply redis.Reply
txId := cluster.idGenerator.NextId() txId := cluster.idGenerator.NextId()
txIdStr := strconv.FormatInt(txId, 10) txIdStr := strconv.FormatInt(txId, 10)
rollback := false rollback := false
for peer, group := range groupMap { for peer, group := range groupMap {
peerArgs := []string{txIdStr} peerArgs := []string{txIdStr}
for _, k := range group { for _, k := range group {
peerArgs = append(peerArgs, k, valueMap[k]) peerArgs = append(peerArgs, k, valueMap[k])
} }
var resp redis.Reply var resp redis.Reply
if peer == cluster.self { if peer == cluster.self {
resp = PrepareMSet(cluster, c, makeArgs("PrepareMSet", peerArgs...)) resp = PrepareMSet(cluster, c, makeArgs("PrepareMSet", peerArgs...))
} else { } else {
resp = cluster.Relay(peer, c, makeArgs("PrepareMSet", peerArgs...)) resp = cluster.Relay(peer, c, makeArgs("PrepareMSet", peerArgs...))
} }
if reply.IsErrorReply(resp) { if reply.IsErrorReply(resp) {
errReply = resp errReply = resp
rollback = true rollback = true
break break
} }
} }
if rollback { if rollback {
// rollback // rollback
RequestRollback(cluster, c, txId, groupMap) RequestRollback(cluster, c, txId, groupMap)
} else { } else {
_, errReply = RequestCommit(cluster, c, txId, groupMap) _, errReply = RequestCommit(cluster, c, txId, groupMap)
rollback = errReply != nil rollback = errReply != nil
} }
if !rollback { if !rollback {
return &reply.OkReply{} return &reply.OkReply{}
} }
return errReply return errReply
} }
func MSetNX(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { func MSetNX(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
argCount := len(args) - 1 argCount := len(args) - 1
if argCount%2 != 0 || argCount < 1 { if argCount%2 != 0 || argCount < 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command") return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
} }
var peer string var peer string
size := argCount / 2 size := argCount / 2
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
key := string(args[2*i]) key := string(args[2*i])
currentPeer := cluster.peerPicker.Get(key) currentPeer := cluster.peerPicker.Get(key)
if peer == "" { if peer == "" {
peer = currentPeer peer = currentPeer
} else { } else {
if peer != currentPeer { if peer != currentPeer {
return reply.MakeErrReply("ERR msetnx must within one slot in cluster mode") return reply.MakeErrReply("ERR msetnx must within one slot in cluster mode")
} }
} }
} }
return cluster.Relay(peer, c, args) return cluster.Relay(peer, c, args)
} }

View File

@@ -1,40 +1,39 @@
package cluster package cluster
import ( import (
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
) )
// TODO: support multiplex slots // TODO: support multiplex slots
func Rename(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { func Rename(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 3 { if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rename' command") return reply.MakeErrReply("ERR wrong number of arguments for 'rename' command")
} }
src := string(args[1]) src := string(args[1])
dest := string(args[2]) dest := string(args[2])
srcPeer := cluster.peerPicker.Get(src) srcPeer := cluster.peerPicker.Get(src)
destPeer := cluster.peerPicker.Get(dest) destPeer := cluster.peerPicker.Get(dest)
if srcPeer != destPeer { if srcPeer != destPeer {
return reply.MakeErrReply("ERR rename must within one slot in cluster mode") return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
} }
return cluster.Relay(srcPeer, c, args) return cluster.Relay(srcPeer, c, args)
} }
func RenameNx(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { func RenameNx(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 3 { if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'renamenx' command") return reply.MakeErrReply("ERR wrong number of arguments for 'renamenx' command")
} }
src := string(args[1]) src := string(args[1])
dest := string(args[2]) dest := string(args[2])
srcPeer := cluster.peerPicker.Get(src) srcPeer := cluster.peerPicker.Get(src)
destPeer := cluster.peerPicker.Get(dest) destPeer := cluster.peerPicker.Get(dest)
if srcPeer != destPeer { if srcPeer != destPeer {
return reply.MakeErrReply("ERR rename must within one slot in cluster mode") return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
} }
return cluster.Relay(srcPeer, c, args) return cluster.Relay(srcPeer, c, args)
} }

View File

@@ -3,106 +3,106 @@ package cluster
import "github.com/HDT3213/godis/src/interface/redis" import "github.com/HDT3213/godis/src/interface/redis"
func MakeRouter() map[string]CmdFunc { func MakeRouter() map[string]CmdFunc {
routerMap := make(map[string]CmdFunc) routerMap := make(map[string]CmdFunc)
routerMap["ping"] = Ping routerMap["ping"] = Ping
routerMap["commit"] = Commit routerMap["commit"] = Commit
routerMap["rollback"] = Rollback routerMap["rollback"] = Rollback
routerMap["del"] = Del routerMap["del"] = Del
routerMap["preparedel"] = PrepareDel routerMap["preparedel"] = PrepareDel
routerMap["preparemset"] = PrepareMSet routerMap["preparemset"] = PrepareMSet
routerMap["expire"] = defaultFunc routerMap["expire"] = defaultFunc
routerMap["expireat"] = defaultFunc routerMap["expireat"] = defaultFunc
routerMap["pexpire"] = defaultFunc routerMap["pexpire"] = defaultFunc
routerMap["pexpireat"] = defaultFunc routerMap["pexpireat"] = defaultFunc
routerMap["ttl"] = defaultFunc routerMap["ttl"] = defaultFunc
routerMap["pttl"] = defaultFunc routerMap["pttl"] = defaultFunc
routerMap["persist"] = defaultFunc routerMap["persist"] = defaultFunc
routerMap["exists"] = defaultFunc routerMap["exists"] = defaultFunc
routerMap["type"] = defaultFunc routerMap["type"] = defaultFunc
routerMap["rename"] = Rename routerMap["rename"] = Rename
routerMap["renamenx"] = RenameNx routerMap["renamenx"] = RenameNx
routerMap["set"] = defaultFunc routerMap["set"] = defaultFunc
routerMap["setnx"] = defaultFunc routerMap["setnx"] = defaultFunc
routerMap["setex"] = defaultFunc routerMap["setex"] = defaultFunc
routerMap["psetex"] = defaultFunc routerMap["psetex"] = defaultFunc
routerMap["mset"] = MSet routerMap["mset"] = MSet
routerMap["mget"] = MGet routerMap["mget"] = MGet
routerMap["msetnx"] = MSetNX routerMap["msetnx"] = MSetNX
routerMap["get"] = defaultFunc routerMap["get"] = defaultFunc
routerMap["getset"] = defaultFunc routerMap["getset"] = defaultFunc
routerMap["incr"] = defaultFunc routerMap["incr"] = defaultFunc
routerMap["incrby"] = defaultFunc routerMap["incrby"] = defaultFunc
routerMap["incrbyfloat"] = defaultFunc routerMap["incrbyfloat"] = defaultFunc
routerMap["decr"] = defaultFunc routerMap["decr"] = defaultFunc
routerMap["decrby"] = defaultFunc routerMap["decrby"] = defaultFunc
routerMap["lpush"] = defaultFunc routerMap["lpush"] = defaultFunc
routerMap["lpushx"] = defaultFunc routerMap["lpushx"] = defaultFunc
routerMap["rpush"] = defaultFunc routerMap["rpush"] = defaultFunc
routerMap["rpushx"] = defaultFunc routerMap["rpushx"] = defaultFunc
routerMap["lpop"] = defaultFunc routerMap["lpop"] = defaultFunc
routerMap["rpop"] = defaultFunc routerMap["rpop"] = defaultFunc
//routerMap["rpoplpush"] = RPopLPush //routerMap["rpoplpush"] = RPopLPush
routerMap["lrem"] = defaultFunc routerMap["lrem"] = defaultFunc
routerMap["llen"] = defaultFunc routerMap["llen"] = defaultFunc
routerMap["lindex"] = defaultFunc routerMap["lindex"] = defaultFunc
routerMap["lset"] = defaultFunc routerMap["lset"] = defaultFunc
routerMap["lrange"] = defaultFunc routerMap["lrange"] = defaultFunc
routerMap["hset"] = defaultFunc routerMap["hset"] = defaultFunc
routerMap["hsetnx"] = defaultFunc routerMap["hsetnx"] = defaultFunc
routerMap["hget"] = defaultFunc routerMap["hget"] = defaultFunc
routerMap["hexists"] = defaultFunc routerMap["hexists"] = defaultFunc
routerMap["hdel"] = defaultFunc routerMap["hdel"] = defaultFunc
routerMap["hlen"] = defaultFunc routerMap["hlen"] = defaultFunc
routerMap["hmget"] = defaultFunc routerMap["hmget"] = defaultFunc
routerMap["hmset"] = defaultFunc routerMap["hmset"] = defaultFunc
routerMap["hkeys"] = defaultFunc routerMap["hkeys"] = defaultFunc
routerMap["hvals"] = defaultFunc routerMap["hvals"] = defaultFunc
routerMap["hgetall"] = defaultFunc routerMap["hgetall"] = defaultFunc
routerMap["hincrby"] = defaultFunc routerMap["hincrby"] = defaultFunc
routerMap["hincrbyfloat"] = defaultFunc routerMap["hincrbyfloat"] = defaultFunc
routerMap["sadd"] = defaultFunc routerMap["sadd"] = defaultFunc
routerMap["sismember"] = defaultFunc routerMap["sismember"] = defaultFunc
routerMap["srem"] = defaultFunc routerMap["srem"] = defaultFunc
routerMap["scard"] = defaultFunc routerMap["scard"] = defaultFunc
routerMap["smembers"] = defaultFunc routerMap["smembers"] = defaultFunc
routerMap["sinter"] = defaultFunc routerMap["sinter"] = defaultFunc
routerMap["sinterstore"] = defaultFunc routerMap["sinterstore"] = defaultFunc
routerMap["sunion"] = defaultFunc routerMap["sunion"] = defaultFunc
routerMap["sunionstore"] = defaultFunc routerMap["sunionstore"] = defaultFunc
routerMap["sdiff"] = defaultFunc routerMap["sdiff"] = defaultFunc
routerMap["sdiffstore"] = defaultFunc routerMap["sdiffstore"] = defaultFunc
routerMap["srandmember"] = defaultFunc routerMap["srandmember"] = defaultFunc
routerMap["zadd"] = defaultFunc routerMap["zadd"] = defaultFunc
routerMap["zscore"] = defaultFunc routerMap["zscore"] = defaultFunc
routerMap["zincrby"] = defaultFunc routerMap["zincrby"] = defaultFunc
routerMap["zrank"] = defaultFunc routerMap["zrank"] = defaultFunc
routerMap["zcount"] = defaultFunc routerMap["zcount"] = defaultFunc
routerMap["zrevrank"] = defaultFunc routerMap["zrevrank"] = defaultFunc
routerMap["zcard"] = defaultFunc routerMap["zcard"] = defaultFunc
routerMap["zrange"] = defaultFunc routerMap["zrange"] = defaultFunc
routerMap["zrevrange"] = defaultFunc routerMap["zrevrange"] = defaultFunc
routerMap["zrangebyscore"] = defaultFunc routerMap["zrangebyscore"] = defaultFunc
routerMap["zrevrangebyscore"] = defaultFunc routerMap["zrevrangebyscore"] = defaultFunc
routerMap["zrem"] = defaultFunc routerMap["zrem"] = defaultFunc
routerMap["zremrangebyscore"] = defaultFunc routerMap["zremrangebyscore"] = defaultFunc
routerMap["zremrangebyrank"] = defaultFunc routerMap["zremrangebyrank"] = defaultFunc
//routerMap["flushdb"] = FlushDB //routerMap["flushdb"] = FlushDB
//routerMap["flushall"] = FlushAll //routerMap["flushall"] = FlushAll
//routerMap["keys"] = Keys //routerMap["keys"] = Keys
return routerMap return routerMap
} }
func defaultFunc(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { func defaultFunc(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
key := string(args[1]) key := string(args[1])
peer := cluster.peerPicker.Get(key) peer := cluster.peerPicker.Get(key)
return cluster.Relay(peer, c, args) return cluster.Relay(peer, c, args)
} }

View File

@@ -1,188 +1,188 @@
package cluster package cluster
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/HDT3213/godis/src/db" "github.com/HDT3213/godis/src/db"
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/lib/marshal/gob" "github.com/HDT3213/godis/src/lib/marshal/gob"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"strconv" "strconv"
"strings" "strings"
"time" "time"
) )
type Transaction struct { type Transaction struct {
id string // transaction id id string // transaction id
args [][]byte // cmd args args [][]byte // cmd args
cluster *Cluster cluster *Cluster
conn redis.Connection conn redis.Connection
keys []string // related keys keys []string // related keys
undoLog map[string][]byte // store data for undoLog undoLog map[string][]byte // store data for undoLog
lockUntil time.Time lockUntil time.Time
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
status int8 status int8
} }
const ( const (
maxLockTime = 3 * time.Second maxLockTime = 3 * time.Second
CreatedStatus = 0 CreatedStatus = 0
PreparedStatus = 1 PreparedStatus = 1
CommitedStatus = 2 CommitedStatus = 2
RollbackedStatus = 3 RollbackedStatus = 3
) )
func NewTransaction(cluster *Cluster, c redis.Connection, id string, args [][]byte, keys []string) *Transaction { func NewTransaction(cluster *Cluster, c redis.Connection, id string, args [][]byte, keys []string) *Transaction {
return &Transaction{ return &Transaction{
id: id, id: id,
args: args, args: args,
cluster: cluster, cluster: cluster,
conn: c, conn: c,
keys: keys, keys: keys,
status: CreatedStatus, status: CreatedStatus,
} }
} }
// t should contains Keys field // t should contains Keys field
func (tx *Transaction) prepare() error { func (tx *Transaction) prepare() error {
// lock keys // lock keys
tx.cluster.db.Locks(tx.keys...) tx.cluster.db.Locks(tx.keys...)
// use context to manage // use context to manage
//tx.lockUntil = time.Now().Add(maxLockTime) //tx.lockUntil = time.Now().Add(maxLockTime)
//ctx, cancel := context.WithDeadline(context.Background(), tx.lockUntil) //ctx, cancel := context.WithDeadline(context.Background(), tx.lockUntil)
//tx.ctx = ctx //tx.ctx = ctx
//tx.cancel = cancel //tx.cancel = cancel
// build undoLog // build undoLog
tx.undoLog = make(map[string][]byte) tx.undoLog = make(map[string][]byte)
for _, key := range tx.keys { for _, key := range tx.keys {
entity, ok := tx.cluster.db.Get(key) entity, ok := tx.cluster.db.Get(key)
if ok { if ok {
blob, err := gob.Marshal(entity) blob, err := gob.Marshal(entity)
if err != nil { if err != nil {
return err return err
} }
tx.undoLog[key] = blob tx.undoLog[key] = blob
} else { } else {
tx.undoLog[key] = []byte{} // entity was nil, should be removed while rollback tx.undoLog[key] = []byte{} // entity was nil, should be removed while rollback
} }
} }
tx.status = PreparedStatus tx.status = PreparedStatus
return nil return nil
} }
func (tx *Transaction) rollback() error { func (tx *Transaction) rollback() error {
for key, blob := range tx.undoLog { for key, blob := range tx.undoLog {
if len(blob) > 0 { if len(blob) > 0 {
entity := &db.DataEntity{} entity := &db.DataEntity{}
err := gob.UnMarshal(blob, entity) err := gob.UnMarshal(blob, entity)
if err != nil { if err != nil {
return err return err
} }
tx.cluster.db.Put(key, entity) tx.cluster.db.Put(key, entity)
} else { } else {
tx.cluster.db.Remove(key) tx.cluster.db.Remove(key)
} }
} }
if tx.status != CommitedStatus { if tx.status != CommitedStatus {
tx.cluster.db.UnLocks(tx.keys...) tx.cluster.db.UnLocks(tx.keys...)
} }
tx.status = RollbackedStatus tx.status = RollbackedStatus
return nil return nil
} }
// rollback local transaction // rollback local transaction
func Rollback(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { func Rollback(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 2 { if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rollback' command") return reply.MakeErrReply("ERR wrong number of arguments for 'rollback' command")
} }
txId := string(args[1]) txId := string(args[1])
raw, ok := cluster.transactions.Get(txId) raw, ok := cluster.transactions.Get(txId)
if !ok { if !ok {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
tx, _ := raw.(*Transaction) tx, _ := raw.(*Transaction)
err := tx.rollback() err := tx.rollback()
if err != nil { if err != nil {
return reply.MakeErrReply(err.Error()) return reply.MakeErrReply(err.Error())
} }
return reply.MakeIntReply(1) return reply.MakeIntReply(1)
} }
// commit local transaction as a worker // commit local transaction as a worker
func Commit(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { func Commit(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if len(args) != 2 { if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'commit' command") return reply.MakeErrReply("ERR wrong number of arguments for 'commit' command")
} }
txId := string(args[1]) txId := string(args[1])
raw, ok := cluster.transactions.Get(txId) raw, ok := cluster.transactions.Get(txId)
if !ok { if !ok {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
tx, _ := raw.(*Transaction) tx, _ := raw.(*Transaction)
// finish transaction // finish transaction
defer func() { defer func() {
cluster.db.UnLocks(tx.keys...) cluster.db.UnLocks(tx.keys...)
tx.status = CommitedStatus tx.status = CommitedStatus
//cluster.transactions.Remove(tx.id) // cannot remove, may rollback after commit //cluster.transactions.Remove(tx.id) // cannot remove, may rollback after commit
}() }()
cmd := strings.ToLower(string(tx.args[0])) cmd := strings.ToLower(string(tx.args[0]))
var result redis.Reply var result redis.Reply
if cmd == "del" { if cmd == "del" {
result = CommitDel(cluster, c, tx) result = CommitDel(cluster, c, tx)
} else if cmd == "mset" { } else if cmd == "mset" {
result = CommitMSet(cluster, c, tx) result = CommitMSet(cluster, c, tx)
} }
if reply.IsErrorReply(result) { if reply.IsErrorReply(result) {
// failed // failed
err2 := tx.rollback() err2 := tx.rollback()
return reply.MakeErrReply(fmt.Sprintf("err occurs when rollback: %v, origin err: %s", err2, result)) return reply.MakeErrReply(fmt.Sprintf("err occurs when rollback: %v, origin err: %s", err2, result))
} }
return result return result
} }
// request all node commit transaction as leader // request all node commit transaction as leader
func RequestCommit(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) ([]redis.Reply, reply.ErrorReply) { func RequestCommit(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) ([]redis.Reply, reply.ErrorReply) {
var errReply reply.ErrorReply var errReply reply.ErrorReply
txIdStr := strconv.FormatInt(txId, 10) txIdStr := strconv.FormatInt(txId, 10)
respList := make([]redis.Reply, 0, len(peers)) respList := make([]redis.Reply, 0, len(peers))
for peer := range peers { for peer := range peers {
var resp redis.Reply var resp redis.Reply
if peer == cluster.self { if peer == cluster.self {
resp = Commit(cluster, c, makeArgs("commit", txIdStr)) resp = Commit(cluster, c, makeArgs("commit", txIdStr))
} else { } else {
resp = cluster.Relay(peer, c, makeArgs("commit", txIdStr)) resp = cluster.Relay(peer, c, makeArgs("commit", txIdStr))
} }
if reply.IsErrorReply(resp) { if reply.IsErrorReply(resp) {
errReply = resp.(reply.ErrorReply) errReply = resp.(reply.ErrorReply)
break break
} }
respList = append(respList, resp) respList = append(respList, resp)
} }
if errReply != nil { if errReply != nil {
RequestRollback(cluster, c, txId, peers) RequestRollback(cluster, c, txId, peers)
return nil, errReply return nil, errReply
} }
return respList, nil return respList, nil
} }
// request all node rollback transaction as leader // request all node rollback transaction as leader
func RequestRollback(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) { func RequestRollback(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) {
txIdStr := strconv.FormatInt(txId, 10) txIdStr := strconv.FormatInt(txId, 10)
for peer := range peers { for peer := range peers {
if peer == cluster.self { if peer == cluster.self {
Rollback(cluster, c, makeArgs("rollback", txIdStr)) Rollback(cluster, c, makeArgs("rollback", txIdStr))
} else { } else {
cluster.Relay(peer, c, makeArgs("rollback", txIdStr)) cluster.Relay(peer, c, makeArgs("rollback", txIdStr))
} }
} }
} }

View File

@@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"github.com/HDT3213/godis/src/config" "github.com/HDT3213/godis/src/config"
"github.com/HDT3213/godis/src/lib/logger" "github.com/HDT3213/godis/src/lib/logger"
RedisServer "github.com/HDT3213/godis/src/redis/server" RedisServer "github.com/HDT3213/godis/src/redis/server"
"github.com/HDT3213/godis/src/tcp" "github.com/HDT3213/godis/src/tcp"
"os" "os"
) )
@@ -23,6 +23,6 @@ func main() {
}) })
tcp.ListenAndServe(&tcp.Config{ tcp.ListenAndServe(&tcp.Config{
Address: fmt.Sprintf("%s:%d", config.Properties.Bind, config.Properties.Port), Address: fmt.Sprintf("%s:%d", config.Properties.Bind, config.Properties.Port),
}, RedisServer.MakeHandler()) }, RedisServer.MakeHandler())
} }

View File

@@ -23,12 +23,12 @@ type PropertyHolder struct {
var Properties *PropertyHolder var Properties *PropertyHolder
func init() { func init() {
// default config // default config
Properties = &PropertyHolder{ Properties = &PropertyHolder{
Bind: "127.0.0.1", Bind: "127.0.0.1",
Port: 6379, Port: 6379,
AppendOnly: false, AppendOnly: false,
} }
} }
func LoadConfig(configFilename string) *PropertyHolder { func LoadConfig(configFilename string) *PropertyHolder {

View File

@@ -1,16 +1,16 @@
package dict package dict
type Consumer func(key string, val interface{})bool type Consumer func(key string, val interface{}) bool
type Dict interface { type Dict interface {
Get(key string) (val interface{}, exists bool) Get(key string) (val interface{}, exists bool)
Len() int Len() int
Put(key string, val interface{}) (result int) Put(key string, val interface{}) (result int)
PutIfAbsent(key string, val interface{}) (result int) PutIfAbsent(key string, val interface{}) (result int)
PutIfExists(key string, val interface{}) (result int) PutIfExists(key string, val interface{}) (result int)
Remove(key string) (result int) Remove(key string) (result int)
ForEach(consumer Consumer) ForEach(consumer Consumer)
Keys() []string Keys() []string
RandomKeys(limit int) []string RandomKeys(limit int) []string
RandomDistinctKeys(limit int) []string RandomDistinctKeys(limit int) []string
} }

View File

@@ -1,244 +1,244 @@
package dict package dict
import ( import (
"strconv" "strconv"
"sync" "sync"
"testing" "testing"
) )
func TestPut(t *testing.T) { func TestPut(t *testing.T) {
d := MakeConcurrent(0) d := MakeConcurrent(0)
count := 100 count := 100
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(count) wg.Add(count)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
go func(i int) { go func(i int) {
// insert // insert
key := "k" + strconv.Itoa(i) key := "k" + strconv.Itoa(i)
ret := d.Put(key, i) ret := d.Put(key, i)
if ret != 1 { // insert 1 if ret != 1 { // insert 1
t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key) t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key)
} }
val, ok := d.Get(key) val, ok := d.Get(key)
if ok { if ok {
intVal, _ := val.(int) intVal, _ := val.(int)
if intVal != i { if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key) t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key)
} }
} else { } else {
_, ok := d.Get(key) _, ok := d.Get(key)
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok)) t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
} }
wg.Done() wg.Done()
}(i) }(i)
} }
wg.Wait() wg.Wait()
} }
func TestPutIfAbsent(t *testing.T) { func TestPutIfAbsent(t *testing.T) {
d := MakeConcurrent(0) d := MakeConcurrent(0)
count := 100 count := 100
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(count) wg.Add(count)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
go func(i int) { go func(i int) {
// insert // insert
key := "k" + strconv.Itoa(i) key := "k" + strconv.Itoa(i)
ret := d.PutIfAbsent(key, i) ret := d.PutIfAbsent(key, i)
if ret != 1 { // insert 1 if ret != 1 { // insert 1
t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key) t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key)
} }
val, ok := d.Get(key) val, ok := d.Get(key)
if ok { if ok {
intVal, _ := val.(int) intVal, _ := val.(int)
if intVal != i { if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) +
", key: " + key) ", key: " + key)
} }
} else { } else {
_, ok := d.Get(key) _, ok := d.Get(key)
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok)) t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
} }
// update // update
ret = d.PutIfAbsent(key, i * 10) ret = d.PutIfAbsent(key, i*10)
if ret != 0 { // no update if ret != 0 { // no update
t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret)) t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret))
} }
val, ok = d.Get(key) val, ok = d.Get(key)
if ok { if ok {
intVal, _ := val.(int) intVal, _ := val.(int)
if intVal != i { if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key) t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key)
} }
} else { } else {
t.Error("put test failed: expected true, actual: false, key: " + key) t.Error("put test failed: expected true, actual: false, key: " + key)
} }
wg.Done() wg.Done()
}(i) }(i)
} }
wg.Wait() wg.Wait()
} }
func TestPutIfExists(t *testing.T) { func TestPutIfExists(t *testing.T) {
d := MakeConcurrent(0) d := MakeConcurrent(0)
count := 100 count := 100
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(count) wg.Add(count)
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
go func(i int) { go func(i int) {
// insert // insert
key := "k" + strconv.Itoa(i) key := "k" + strconv.Itoa(i)
// insert // insert
ret := d.PutIfExists(key, i) ret := d.PutIfExists(key, i)
if ret != 0 { // insert if ret != 0 { // insert
t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret)) t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret))
} }
d.Put(key, i) d.Put(key, i)
ret = d.PutIfExists(key, 10 * i) ret = d.PutIfExists(key, 10*i)
val, ok := d.Get(key) val, ok := d.Get(key)
if ok { if ok {
intVal, _ := val.(int) intVal, _ := val.(int)
if intVal != 10 * i { if intVal != 10*i {
t.Error("put test failed: expected " + strconv.Itoa(10 * i) + ", actual: " + strconv.Itoa(intVal)) t.Error("put test failed: expected " + strconv.Itoa(10*i) + ", actual: " + strconv.Itoa(intVal))
} }
} else { } else {
_, ok := d.Get(key) _, ok := d.Get(key)
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok)) t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
} }
wg.Done() wg.Done()
}(i) }(i)
} }
wg.Wait() wg.Wait()
} }
func TestRemove(t *testing.T) { func TestRemove(t *testing.T) {
d := MakeConcurrent(0) d := MakeConcurrent(0)
// remove head node // remove head node
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
// insert // insert
key := "k" + strconv.Itoa(i) key := "k" + strconv.Itoa(i)
d.Put(key, i) d.Put(key, i)
} }
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
key := "k" + strconv.Itoa(i) key := "k" + strconv.Itoa(i)
val, ok := d.Get(key) val, ok := d.Get(key)
if ok { if ok {
intVal, _ := val.(int) intVal, _ := val.(int)
if intVal != i { if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
} }
} else { } else {
t.Error("put test failed: expected true, actual: false") t.Error("put test failed: expected true, actual: false")
} }
ret := d.Remove(key) ret := d.Remove(key)
if ret != 1 { if ret != 1 {
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key:" + key) t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key:" + key)
} }
_, ok = d.Get(key) _, ok = d.Get(key)
if ok { if ok {
t.Error("remove test failed: expected true, actual false") t.Error("remove test failed: expected true, actual false")
} }
ret = d.Remove(key) ret = d.Remove(key)
if ret != 0 { if ret != 0 {
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
} }
} }
// remove tail node // remove tail node
d = MakeConcurrent(0) d = MakeConcurrent(0)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
// insert // insert
key := "k" + strconv.Itoa(i) key := "k" + strconv.Itoa(i)
d.Put(key, i) d.Put(key, i)
} }
for i := 9; i >= 0; i-- { for i := 9; i >= 0; i-- {
key := "k" + strconv.Itoa(i) key := "k" + strconv.Itoa(i)
val, ok := d.Get(key) val, ok := d.Get(key)
if ok { if ok {
intVal, _ := val.(int) intVal, _ := val.(int)
if intVal != i { if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
} }
} else { } else {
t.Error("put test failed: expected true, actual: false") t.Error("put test failed: expected true, actual: false")
} }
ret := d.Remove(key) ret := d.Remove(key)
if ret != 1 { if ret != 1 {
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret)) t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret))
} }
_, ok = d.Get(key) _, ok = d.Get(key)
if ok { if ok {
t.Error("remove test failed: expected true, actual false") t.Error("remove test failed: expected true, actual false")
} }
ret = d.Remove(key) ret = d.Remove(key)
if ret != 0 { if ret != 0 {
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
} }
} }
// remove middle node // remove middle node
d = MakeConcurrent(0) d = MakeConcurrent(0)
d.Put("head", 0) d.Put("head", 0)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
// insert // insert
key := "k" + strconv.Itoa(i) key := "k" + strconv.Itoa(i)
d.Put(key, i) d.Put(key, i)
} }
d.Put("tail", 0) d.Put("tail", 0)
for i := 9; i >= 0; i-- { for i := 9; i >= 0; i-- {
key := "k" + strconv.Itoa(i) key := "k" + strconv.Itoa(i)
val, ok := d.Get(key) val, ok := d.Get(key)
if ok { if ok {
intVal, _ := val.(int) intVal, _ := val.(int)
if intVal != i { if intVal != i {
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
} }
} else { } else {
t.Error("put test failed: expected true, actual: false") t.Error("put test failed: expected true, actual: false")
} }
ret := d.Remove(key) ret := d.Remove(key)
if ret != 1 { if ret != 1 {
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret)) t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret))
} }
_, ok = d.Get(key) _, ok = d.Get(key)
if ok { if ok {
t.Error("remove test failed: expected true, actual false") t.Error("remove test failed: expected true, actual false")
} }
ret = d.Remove(key) ret = d.Remove(key)
if ret != 0 { if ret != 0 {
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret)) t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
} }
} }
} }
func TestForEach(t *testing.T) { func TestForEach(t *testing.T) {
d := MakeConcurrent(0) d := MakeConcurrent(0)
size := 100 size := 100
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
// insert // insert
key := "k" + strconv.Itoa(i) key := "k" + strconv.Itoa(i)
d.Put(key, i) d.Put(key, i)
} }
i := 0 i := 0
d.ForEach(func(key string, value interface{})bool { d.ForEach(func(key string, value interface{}) bool {
intVal, _ := value.(int) intVal, _ := value.(int)
expectedKey := "k" + strconv.Itoa(intVal) expectedKey := "k" + strconv.Itoa(intVal)
if key != expectedKey { if key != expectedKey {
t.Error("remove test failed: expected " + expectedKey + ", actual: " + key) t.Error("remove test failed: expected " + expectedKey + ", actual: " + key)
} }
i++ i++
return true return true
}) })
if i != size { if i != size {
t.Error("remove test failed: expected " + strconv.Itoa(size) + ", actual: " + strconv.Itoa(i)) t.Error("remove test failed: expected " + strconv.Itoa(size) + ", actual: " + strconv.Itoa(i))
} }
} }

View File

@@ -1,108 +1,108 @@
package dict package dict
type SimpleDict struct { type SimpleDict struct {
m map[string]interface{} m map[string]interface{}
} }
func MakeSimple() *SimpleDict { func MakeSimple() *SimpleDict {
return &SimpleDict{ return &SimpleDict{
m: make(map[string]interface{}), m: make(map[string]interface{}),
} }
} }
func (dict *SimpleDict) Get(key string) (val interface{}, exists bool) { func (dict *SimpleDict) Get(key string) (val interface{}, exists bool) {
val, ok := dict.m[key] val, ok := dict.m[key]
return val, ok return val, ok
} }
func (dict *SimpleDict) Len() int { func (dict *SimpleDict) Len() int {
if dict.m == nil { if dict.m == nil {
panic("m is nil") panic("m is nil")
} }
return len(dict.m) return len(dict.m)
} }
func (dict *SimpleDict) Put(key string, val interface{}) (result int) { func (dict *SimpleDict) Put(key string, val interface{}) (result int) {
_, existed := dict.m[key] _, existed := dict.m[key]
dict.m[key] = val dict.m[key] = val
if existed { if existed {
return 0 return 0
} else { } else {
return 1 return 1
} }
} }
func (dict *SimpleDict) PutIfAbsent(key string, val interface{}) (result int) { func (dict *SimpleDict) PutIfAbsent(key string, val interface{}) (result int) {
_, existed := dict.m[key] _, existed := dict.m[key]
if existed { if existed {
return 0 return 0
} else { } else {
dict.m[key] = val dict.m[key] = val
return 1 return 1
} }
} }
func (dict *SimpleDict) PutIfExists(key string, val interface{}) (result int) { func (dict *SimpleDict) PutIfExists(key string, val interface{}) (result int) {
_, existed := dict.m[key] _, existed := dict.m[key]
if existed { if existed {
dict.m[key] = val dict.m[key] = val
return 1 return 1
} else { } else {
return 0 return 0
} }
} }
func (dict *SimpleDict) Remove(key string) (result int) { func (dict *SimpleDict) Remove(key string) (result int) {
_, existed := dict.m[key] _, existed := dict.m[key]
delete(dict.m, key) delete(dict.m, key)
if existed { if existed {
return 1 return 1
} else { } else {
return 0 return 0
} }
} }
func (dict *SimpleDict) Keys() []string { func (dict *SimpleDict) Keys() []string {
result := make([]string, len(dict.m)) result := make([]string, len(dict.m))
i := 0 i := 0
for k := range dict.m { for k := range dict.m {
result[i] = k result[i] = k
} }
return result return result
} }
func (dict *SimpleDict) ForEach(consumer Consumer) { func (dict *SimpleDict) ForEach(consumer Consumer) {
for k, v := range dict.m { for k, v := range dict.m {
if !consumer(k, v) { if !consumer(k, v) {
break break
} }
} }
} }
func (dict *SimpleDict) RandomKeys(limit int) []string { func (dict *SimpleDict) RandomKeys(limit int) []string {
result := make([]string, limit) result := make([]string, limit)
for i := 0; i < limit; i++ { for i := 0; i < limit; i++ {
for k := range dict.m { for k := range dict.m {
result[i] = k result[i] = k
break break
} }
} }
return result return result
} }
func (dict *SimpleDict) RandomDistinctKeys(limit int) []string { func (dict *SimpleDict) RandomDistinctKeys(limit int) []string {
size := limit size := limit
if size > len(dict.m) { if size > len(dict.m) {
size = len(dict.m) size = len(dict.m)
} }
result := make([]string, size) result := make([]string, size)
i := 0 i := 0
for k := range dict.m { for k := range dict.m {
if i == limit { if i == limit {
break break
} }
result[i] = k result[i] = k
i++ i++
} }
return result return result
} }

View File

@@ -3,324 +3,324 @@ package list
import "github.com/HDT3213/godis/src/datastruct/utils" import "github.com/HDT3213/godis/src/datastruct/utils"
type LinkedList struct { type LinkedList struct {
first *node first *node
last *node last *node
size int size int
} }
type node struct { type node struct {
val interface{} val interface{}
prev *node prev *node
next * node next *node
} }
func (list *LinkedList)Add(val interface{}) { func (list *LinkedList) Add(val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
n := &node{ n := &node{
val: val, val: val,
} }
if list.last == nil { if list.last == nil {
// empty list // empty list
list.first = n list.first = n
list.last = n list.last = n
} else { } else {
n.prev = list.last n.prev = list.last
list.last.next = n list.last.next = n
list.last = n list.last = n
} }
list.size++ list.size++
} }
func (list *LinkedList)find(index int)(n *node) { func (list *LinkedList) find(index int) (n *node) {
if index < list.size / 2 { if index < list.size/2 {
n := list.first n := list.first
for i := 0; i < index; i++ { for i := 0; i < index; i++ {
n = n.next n = n.next
} }
return n return n
} else { } else {
n := list.last n := list.last
for i := list.size - 1; i > index; i-- { for i := list.size - 1; i > index; i-- {
n = n.prev n = n.prev
} }
return n return n
} }
} }
func (list *LinkedList)Get(index int)(val interface{}) { func (list *LinkedList) Get(index int) (val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
if index < 0 || index >= list.size { if index < 0 || index >= list.size {
panic("index out of bound") panic("index out of bound")
} }
return list.find(index).val return list.find(index).val
} }
func (list *LinkedList)Set(index int, val interface{}) { func (list *LinkedList) Set(index int, val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
if index < 0 || index > list.size { if index < 0 || index > list.size {
panic("index out of bound") panic("index out of bound")
} }
n := list.find(index) n := list.find(index)
n.val = val n.val = val
} }
func (list *LinkedList)Insert(index int, val interface{}) { func (list *LinkedList) Insert(index int, val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
if index < 0 || index > list.size { if index < 0 || index > list.size {
panic("index out of bound") panic("index out of bound")
} }
if index == list.size { if index == list.size {
list.Add(val) list.Add(val)
return return
} else { } else {
// list is not empty // list is not empty
pivot := list.find(index) pivot := list.find(index)
n := &node{ n := &node{
val: val, val: val,
prev: pivot.prev, prev: pivot.prev,
next: pivot, next: pivot,
} }
if pivot.prev == nil { if pivot.prev == nil {
list.first = n list.first = n
} else { } else {
pivot.prev.next = n pivot.prev.next = n
} }
pivot.prev = n pivot.prev = n
list.size++ list.size++
} }
} }
func (list *LinkedList)removeNode(n *node) { func (list *LinkedList) removeNode(n *node) {
if n.prev == nil { if n.prev == nil {
list.first = n.next list.first = n.next
} else { } else {
n.prev.next = n.next n.prev.next = n.next
} }
if n.next == nil { if n.next == nil {
list.last = n.prev list.last = n.prev
} else { } else {
n.next.prev = n.prev n.next.prev = n.prev
} }
// for gc // for gc
n.prev = nil n.prev = nil
n.next = nil n.next = nil
list.size-- list.size--
} }
func (list *LinkedList)Remove(index int)(val interface{}) { func (list *LinkedList) Remove(index int) (val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
if index < 0 || index >= list.size { if index < 0 || index >= list.size {
panic("index out of bound") panic("index out of bound")
} }
n := list.find(index) n := list.find(index)
list.removeNode(n) list.removeNode(n)
return n.val return n.val
} }
func (list *LinkedList)RemoveLast()(val interface{}) { func (list *LinkedList) RemoveLast() (val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
if list.last == nil { if list.last == nil {
// empty list // empty list
return nil return nil
} }
n := list.last n := list.last
list.removeNode(n) list.removeNode(n)
return n.val return n.val
} }
func (list *LinkedList)RemoveAllByVal(val interface{})int { func (list *LinkedList) RemoveAllByVal(val interface{}) int {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
n := list.first n := list.first
removed := 0 removed := 0
for n != nil { for n != nil {
var toRemoveNode *node var toRemoveNode *node
if utils.Equals(n.val, val) { if utils.Equals(n.val, val) {
toRemoveNode = n toRemoveNode = n
} }
if n.next == nil { if n.next == nil {
if toRemoveNode != nil { if toRemoveNode != nil {
removed++ removed++
list.removeNode(toRemoveNode) list.removeNode(toRemoveNode)
} }
break break
} else { } else {
n = n.next n = n.next
} }
if toRemoveNode != nil { if toRemoveNode != nil {
removed++ removed++
list.removeNode(toRemoveNode) list.removeNode(toRemoveNode)
} }
} }
return removed return removed
} }
/** /**
* remove at most `count` values of the specified value in this list * remove at most `count` values of the specified value in this list
* scan from left to right * scan from left to right
*/ */
func (list *LinkedList) RemoveByVal(val interface{}, count int)int { func (list *LinkedList) RemoveByVal(val interface{}, count int) int {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
n := list.first n := list.first
removed := 0 removed := 0
for n != nil { for n != nil {
var toRemoveNode *node var toRemoveNode *node
if utils.Equals(n.val, val) { if utils.Equals(n.val, val) {
toRemoveNode = n toRemoveNode = n
} }
if n.next == nil { if n.next == nil {
if toRemoveNode != nil { if toRemoveNode != nil {
removed++ removed++
list.removeNode(toRemoveNode) list.removeNode(toRemoveNode)
} }
break break
} else { } else {
n = n.next n = n.next
} }
if toRemoveNode != nil { if toRemoveNode != nil {
removed++ removed++
list.removeNode(toRemoveNode) list.removeNode(toRemoveNode)
} }
if removed == count { if removed == count {
break break
} }
} }
return removed return removed
} }
func (list *LinkedList) ReverseRemoveByVal(val interface{}, count int)int { func (list *LinkedList) ReverseRemoveByVal(val interface{}, count int) int {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
n := list.last n := list.last
removed := 0 removed := 0
for n != nil { for n != nil {
var toRemoveNode *node var toRemoveNode *node
if utils.Equals(n.val, val) { if utils.Equals(n.val, val) {
toRemoveNode = n toRemoveNode = n
} }
if n.prev == nil { if n.prev == nil {
if toRemoveNode != nil { if toRemoveNode != nil {
removed++ removed++
list.removeNode(toRemoveNode) list.removeNode(toRemoveNode)
} }
break break
} else { } else {
n = n.prev n = n.prev
} }
if toRemoveNode != nil { if toRemoveNode != nil {
removed++ removed++
list.removeNode(toRemoveNode) list.removeNode(toRemoveNode)
} }
if removed == count { if removed == count {
break break
} }
} }
return removed return removed
} }
func (list *LinkedList)Len()int { func (list *LinkedList) Len() int {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
return list.size return list.size
} }
func (list *LinkedList)ForEach(consumer func(int, interface{})bool) { func (list *LinkedList) ForEach(consumer func(int, interface{}) bool) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
n := list.first n := list.first
i := 0 i := 0
for n != nil { for n != nil {
goNext := consumer(i, n.val) goNext := consumer(i, n.val)
if !goNext || n.next == nil { if !goNext || n.next == nil {
break break
} else { } else {
i++ i++
n = n.next n = n.next
} }
} }
} }
func (list *LinkedList)Contains(val interface{})bool { func (list *LinkedList) Contains(val interface{}) bool {
contains := false contains := false
list.ForEach(func(i int, actual interface{}) bool { list.ForEach(func(i int, actual interface{}) bool {
if actual == val { if actual == val {
contains = true contains = true
return false return false
} }
return true return true
}) })
return contains return contains
} }
func (list *LinkedList)Range(start int, stop int)[]interface{} { func (list *LinkedList) Range(start int, stop int) []interface{} {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
if start < 0 || start >= list.size { if start < 0 || start >= list.size {
panic("`start` out of range") panic("`start` out of range")
} }
if stop < start || stop > list.size { if stop < start || stop > list.size {
panic("`stop` out of range") panic("`stop` out of range")
} }
sliceSize := stop - start sliceSize := stop - start
slice := make([]interface{}, sliceSize) slice := make([]interface{}, sliceSize)
n := list.first n := list.first
i := 0 i := 0
sliceIndex := 0 sliceIndex := 0
for n != nil { for n != nil {
if i >= start && i < stop { if i >= start && i < stop {
slice[sliceIndex] = n.val slice[sliceIndex] = n.val
sliceIndex++ sliceIndex++
} else if i >= stop { } else if i >= stop {
break break
} }
if n.next == nil { if n.next == nil {
break break
} else { } else {
i++ i++
n = n.next n = n.next
} }
} }
return slice return slice
} }
func Make(vals ...interface{}) *LinkedList { func Make(vals ...interface{}) *LinkedList {
list := LinkedList{} list := LinkedList{}
for _, v := range vals { for _, v := range vals {
list.Add(v) list.Add(v)
} }
return &list return &list
} }
func MakeBytesList(vals ...[]byte) *LinkedList { func MakeBytesList(vals ...[]byte) *LinkedList {
list := LinkedList{} list := LinkedList{}
for _, v := range vals { for _, v := range vals {
list.Add(v) list.Add(v)
} }
return &list return &list
} }

View File

@@ -1,215 +1,215 @@
package list package list
import ( import (
"testing" "strconv"
"strconv" "strings"
"strings" "testing"
) )
func ToString(list *LinkedList) string { func ToString(list *LinkedList) string {
arr := make([]string, list.size) arr := make([]string, list.size)
list.ForEach(func(i int, v interface{}) bool { list.ForEach(func(i int, v interface{}) bool {
integer, _ := v.(int) integer, _ := v.(int)
arr[i] = strconv.Itoa(integer) arr[i] = strconv.Itoa(integer)
return true return true
}) })
return "[" + strings.Join(arr, ", ") + "]" return "[" + strings.Join(arr, ", ") + "]"
} }
func TestAdd(t *testing.T) { func TestAdd(t *testing.T) {
list := Make() list := Make()
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
list.Add(i) list.Add(i)
} }
list.ForEach(func(i int, v interface{}) bool { list.ForEach(func(i int, v interface{}) bool {
intVal, _ := v.(int) intVal, _ := v.(int)
if intVal != i { if intVal != i {
t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
} }
return true return true
}) })
} }
func TestGet(t *testing.T) { func TestGet(t *testing.T) {
list := Make() list := Make()
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
list.Add(i) list.Add(i)
} }
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
v := list.Get(i) v := list.Get(i)
k, _ := v.(int) k, _ := v.(int)
if i != k { if i != k {
t.Error("get test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(k)) t.Error("get test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(k))
} }
} }
} }
func TestRemove(t *testing.T) { func TestRemove(t *testing.T) {
list := Make() list := Make()
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
list.Add(i) list.Add(i)
} }
for i := 9; i >= 0; i-- { for i := 9; i >= 0; i-- {
list.Remove(i) list.Remove(i)
if i != list.Len() { if i != list.Len() {
t.Error("remove test fail: expected size " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(list.Len())) t.Error("remove test fail: expected size " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(list.Len()))
} }
list.ForEach(func(i int, v interface{}) bool { list.ForEach(func(i int, v interface{}) bool {
intVal, _ := v.(int) intVal, _ := v.(int)
if intVal != i { if intVal != i {
t.Error("remove test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) t.Error("remove test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
} }
return true return true
}) })
} }
} }
func TestRemoveVal(t *testing.T) { func TestRemoveVal(t *testing.T) {
list := Make() list := Make()
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
list.Add(i) list.Add(i)
list.Add(i) list.Add(i)
} }
for index := 0; index < list.Len(); index++ { for index := 0; index < list.Len(); index++ {
list.RemoveAllByVal(index) list.RemoveAllByVal(index)
list.ForEach(func(i int, v interface{}) bool { list.ForEach(func(i int, v interface{}) bool {
intVal, _ := v.(int) intVal, _ := v.(int)
if intVal == index { if intVal == index {
t.Error("remove test fail: found " + strconv.Itoa(index) + " at index: " + strconv.Itoa(i)) t.Error("remove test fail: found " + strconv.Itoa(index) + " at index: " + strconv.Itoa(i))
} }
return true return true
}) })
} }
list = Make() list = Make()
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
list.Add(i) list.Add(i)
list.Add(i) list.Add(i)
} }
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
list.RemoveByVal(i, 1) list.RemoveByVal(i, 1)
} }
list.ForEach(func(i int, v interface{}) bool { list.ForEach(func(i int, v interface{}) bool {
intVal, _ := v.(int) intVal, _ := v.(int)
if intVal != i { if intVal != i {
t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
} }
return true return true
}) })
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
list.RemoveByVal(i, 1) list.RemoveByVal(i, 1)
} }
if list.Len() != 0 { if list.Len() != 0 {
t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len())) t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len()))
} }
list = Make() list = Make()
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
list.Add(i) list.Add(i)
list.Add(i) list.Add(i)
} }
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
list.ReverseRemoveByVal(i, 1) list.ReverseRemoveByVal(i, 1)
} }
list.ForEach(func(i int, v interface{}) bool { list.ForEach(func(i int, v interface{}) bool {
intVal, _ := v.(int) intVal, _ := v.(int)
if intVal != i { if intVal != i {
t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
} }
return true return true
}) })
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
list.ReverseRemoveByVal(i, 1) list.ReverseRemoveByVal(i, 1)
} }
if list.Len() != 0 { if list.Len() != 0 {
t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len())) t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len()))
} }
} }
func TestInsert(t *testing.T) { func TestInsert(t *testing.T) {
list := Make() list := Make()
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
list.Add(i) list.Add(i)
} }
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
list.Insert(i*2, i) list.Insert(i*2, i)
list.ForEach(func(j int, v interface{}) bool { list.ForEach(func(j int, v interface{}) bool {
var expected int var expected int
if j < (i + 1) * 2 { if j < (i+1)*2 {
if j%2 == 0 { if j%2 == 0 {
expected = j / 2 expected = j / 2
} else { } else {
expected = (j - 1) / 2 expected = (j - 1) / 2
} }
} else { } else {
expected = j - i - 1 expected = j - i - 1
} }
actual, _ := list.Get(j).(int) actual, _ := list.Get(j).(int)
if actual != expected { if actual != expected {
t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual)) t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual))
} }
return true return true
}) })
for j := 0; j < list.Len(); j++ { for j := 0; j < list.Len(); j++ {
var expected int var expected int
if j < (i + 1) * 2 { if j < (i+1)*2 {
if j%2 == 0 { if j%2 == 0 {
expected = j / 2 expected = j / 2
} else { } else {
expected = (j - 1) / 2 expected = (j - 1) / 2
} }
} else { } else {
expected = j - i - 1 expected = j - i - 1
} }
actual, _ := list.Get(j).(int) actual, _ := list.Get(j).(int)
if actual != expected { if actual != expected {
t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual)) t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual))
} }
} }
} }
} }
func TestRemoveLast(t *testing.T) { func TestRemoveLast(t *testing.T) {
list := Make() list := Make()
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
list.Add(i) list.Add(i)
} }
for i := 9; i >= 0; i-- { for i := 9; i >= 0; i-- {
val := list.RemoveLast() val := list.RemoveLast()
intVal, _ := val.(int) intVal, _ := val.(int)
if intVal != i { if intVal != i {
t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal)) t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
} }
} }
} }
func TestRange(t *testing.T) { func TestRange(t *testing.T) {
list := Make() list := Make()
size := 10 size := 10
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
list.Add(i) list.Add(i)
} }
for start := 0; start < size; start++ { for start := 0; start < size; start++ {
for stop := start; stop < size; stop++ { for stop := start; stop < size; stop++ {
slice := list.Range(start, stop) slice := list.Range(start, stop)
if len(slice) != stop - start { if len(slice) != stop-start {
t.Error("expected " + strconv.Itoa(stop - start) + ", get: " + strconv.Itoa(len(slice)) + t.Error("expected " + strconv.Itoa(stop-start) + ", get: " + strconv.Itoa(len(slice)) +
", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]") ", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]")
} }
sliceIndex := 0 sliceIndex := 0
for i := start; i < stop; i++ { for i := start; i < stop; i++ {
val := slice[sliceIndex] val := slice[sliceIndex]
intVal, _ := val.(int) intVal, _ := val.(int)
if intVal != i { if intVal != i {
t.Error("expected " + strconv.Itoa(i) + ", get: " + strconv.Itoa(intVal) + t.Error("expected " + strconv.Itoa(i) + ", get: " + strconv.Itoa(intVal) +
", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]") ", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]")
} }
sliceIndex++ sliceIndex++
} }
} }
} }
} }

View File

@@ -1,153 +1,152 @@
package lock package lock
import ( import (
"fmt" "fmt"
"runtime" "runtime"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"testing" "testing"
"time" "time"
) )
const ( const (
prime32 = uint32(16777619) prime32 = uint32(16777619)
) )
type Locks struct { type Locks struct {
table []*sync.RWMutex table []*sync.RWMutex
} }
func Make(tableSize int) *Locks { func Make(tableSize int) *Locks {
table := make([]*sync.RWMutex, tableSize) table := make([]*sync.RWMutex, tableSize)
for i := 0; i < tableSize; i++ { for i := 0; i < tableSize; i++ {
table[i] = &sync.RWMutex{} table[i] = &sync.RWMutex{}
} }
return &Locks{ return &Locks{
table: table, table: table,
} }
} }
func fnv32(key string) uint32 { func fnv32(key string) uint32 {
hash := uint32(2166136261) hash := uint32(2166136261)
for i := 0; i < len(key); i++ { for i := 0; i < len(key); i++ {
hash *= prime32 hash *= prime32
hash ^= uint32(key[i]) hash ^= uint32(key[i])
} }
return hash return hash
} }
func (locks *Locks) spread(hashCode uint32) uint32 { func (locks *Locks) spread(hashCode uint32) uint32 {
if locks == nil { if locks == nil {
panic("dict is nil") panic("dict is nil")
} }
tableSize := uint32(len(locks.table)) tableSize := uint32(len(locks.table))
return (tableSize - 1) & uint32(hashCode) return (tableSize - 1) & uint32(hashCode)
} }
func (locks *Locks)Lock(key string) { func (locks *Locks) Lock(key string) {
index := locks.spread(fnv32(key)) index := locks.spread(fnv32(key))
mu := locks.table[index] mu := locks.table[index]
mu.Lock() mu.Lock()
} }
func (locks *Locks)RLock(key string) { func (locks *Locks) RLock(key string) {
index := locks.spread(fnv32(key)) index := locks.spread(fnv32(key))
mu := locks.table[index] mu := locks.table[index]
mu.RLock() mu.RLock()
} }
func (locks *Locks)UnLock(key string) { func (locks *Locks) UnLock(key string) {
index := locks.spread(fnv32(key)) index := locks.spread(fnv32(key))
mu := locks.table[index] mu := locks.table[index]
mu.Unlock() mu.Unlock()
} }
func (locks *Locks)RUnLock(key string) { func (locks *Locks) RUnLock(key string) {
index := locks.spread(fnv32(key)) index := locks.spread(fnv32(key))
mu := locks.table[index] mu := locks.table[index]
mu.RUnlock() mu.RUnlock()
} }
func (locks *Locks) toLockIndices(keys []string, reverse bool) []uint32 { func (locks *Locks) toLockIndices(keys []string, reverse bool) []uint32 {
indexMap := make(map[uint32]bool) indexMap := make(map[uint32]bool)
for _, key := range keys { for _, key := range keys {
index := locks.spread(fnv32(key)) index := locks.spread(fnv32(key))
indexMap[index] = true indexMap[index] = true
} }
indices := make([]uint32, 0, len(indexMap)) indices := make([]uint32, 0, len(indexMap))
for index := range indexMap { for index := range indexMap {
indices = append(indices, index) indices = append(indices, index)
} }
sort.Slice(indices, func(i, j int) bool { sort.Slice(indices, func(i, j int) bool {
if !reverse { if !reverse {
return indices[i] < indices[j] return indices[i] < indices[j]
} else { } else {
return indices[i] > indices[j] return indices[i] > indices[j]
} }
}) })
return indices return indices
} }
func (locks *Locks)Locks(keys ...string) { func (locks *Locks) Locks(keys ...string) {
indices := locks.toLockIndices(keys, false) indices := locks.toLockIndices(keys, false)
for _, index := range indices { for _, index := range indices {
mu := locks.table[index] mu := locks.table[index]
mu.Lock() mu.Lock()
} }
} }
func (locks *Locks)RLocks(keys ...string) { func (locks *Locks) RLocks(keys ...string) {
indices := locks.toLockIndices(keys, false) indices := locks.toLockIndices(keys, false)
for _, index := range indices { for _, index := range indices {
mu := locks.table[index] mu := locks.table[index]
mu.RLock() mu.RLock()
} }
} }
func (locks *Locks) UnLocks(keys ...string) {
func (locks *Locks)UnLocks(keys ...string) { indices := locks.toLockIndices(keys, true)
indices := locks.toLockIndices(keys, true) for _, index := range indices {
for _, index := range indices { mu := locks.table[index]
mu := locks.table[index] mu.Unlock()
mu.Unlock() }
}
} }
func (locks *Locks)RUnLocks(keys ...string) { func (locks *Locks) RUnLocks(keys ...string) {
indices := locks.toLockIndices(keys, true) indices := locks.toLockIndices(keys, true)
for _, index := range indices { for _, index := range indices {
mu := locks.table[index] mu := locks.table[index]
mu.RUnlock() mu.RUnlock()
} }
} }
func GoID() int { func GoID() int {
var buf [64]byte var buf [64]byte
n := runtime.Stack(buf[:], false) n := runtime.Stack(buf[:], false)
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0] idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
id, err := strconv.Atoi(idField) id, err := strconv.Atoi(idField)
if err != nil { if err != nil {
panic(fmt.Sprintf("cannot get goroutine id: %v", err)) panic(fmt.Sprintf("cannot get goroutine id: %v", err))
} }
return id return id
} }
func debug(testing.T) { func debug(testing.T) {
lm := Locks{} lm := Locks{}
size := 10 size := 10
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(size) wg.Add(size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
go func(i int) { go func(i int) {
lm.Locks("1", "2") lm.Locks("1", "2")
println("go: " + strconv.Itoa(GoID())) println("go: " + strconv.Itoa(GoID()))
time.Sleep(time.Second) time.Sleep(time.Second)
println("go: " + strconv.Itoa(GoID())) println("go: " + strconv.Itoa(GoID()))
lm.UnLocks("1", "2") lm.UnLocks("1", "2")
wg.Done() wg.Done()
}(i) }(i)
} }
wg.Wait() wg.Wait()
} }

View File

@@ -1,32 +1,32 @@
package set package set
import ( import (
"strconv" "strconv"
"testing" "testing"
) )
func TestSet(t *testing.T) { func TestSet(t *testing.T) {
size := 10 size := 10
set := Make() set := Make()
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
set.Add(strconv.Itoa(i)) set.Add(strconv.Itoa(i))
} }
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
ok := set.Has(strconv.Itoa(i)) ok := set.Has(strconv.Itoa(i))
if !ok { if !ok {
t.Error("expected true actual false, key: " + strconv.Itoa(i)) t.Error("expected true actual false, key: " + strconv.Itoa(i))
} }
} }
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
ok := set.Remove(strconv.Itoa(i)) ok := set.Remove(strconv.Itoa(i))
if ok != 1 { if ok != 1 {
t.Error("expected true actual false, key: " + strconv.Itoa(i)) t.Error("expected true actual false, key: " + strconv.Itoa(i))
} }
} }
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
ok := set.Has(strconv.Itoa(i)) ok := set.Has(strconv.Itoa(i))
if ok { if ok {
t.Error("expected false actual true, key: " + strconv.Itoa(i)) t.Error("expected false actual true, key: " + strconv.Itoa(i))
} }
} }
} }

View File

@@ -1,8 +1,8 @@
package sortedset package sortedset
import ( import (
"errors" "errors"
"strconv" "strconv"
) )
/* /*
@@ -14,78 +14,78 @@ import (
*/ */
const ( const (
negativeInf int8 = -1 negativeInf int8 = -1
positiveInf int8 = 1 positiveInf int8 = 1
) )
type ScoreBorder struct { type ScoreBorder struct {
Inf int8 Inf int8
Value float64 Value float64
Exclude bool Exclude bool
} }
// if max.greater(score) then the score is within the upper border // if max.greater(score) then the score is within the upper border
// do not use min.greater() // do not use min.greater()
func (border *ScoreBorder)greater(value float64)bool { func (border *ScoreBorder) greater(value float64) bool {
if border.Inf == negativeInf { if border.Inf == negativeInf {
return false return false
} else if border.Inf == positiveInf { } else if border.Inf == positiveInf {
return true return true
} }
if border.Exclude { if border.Exclude {
return border.Value > value return border.Value > value
} else { } else {
return border.Value >= value return border.Value >= value
} }
} }
func (border *ScoreBorder)less(value float64)bool { func (border *ScoreBorder) less(value float64) bool {
if border.Inf == negativeInf { if border.Inf == negativeInf {
return true return true
} else if border.Inf == positiveInf { } else if border.Inf == positiveInf {
return false return false
} }
if border.Exclude { if border.Exclude {
return border.Value < value return border.Value < value
} else { } else {
return border.Value <= value return border.Value <= value
} }
} }
var positiveInfBorder = &ScoreBorder { var positiveInfBorder = &ScoreBorder{
Inf: positiveInf, Inf: positiveInf,
} }
var negativeInfBorder = &ScoreBorder { var negativeInfBorder = &ScoreBorder{
Inf: negativeInf, Inf: negativeInf,
} }
func ParseScoreBorder(s string)(*ScoreBorder, error) { func ParseScoreBorder(s string) (*ScoreBorder, error) {
if s == "inf" || s == "+inf" { if s == "inf" || s == "+inf" {
return positiveInfBorder, nil return positiveInfBorder, nil
} }
if s == "-inf" { if s == "-inf" {
return negativeInfBorder, nil return negativeInfBorder, nil
} }
if s[0] == '(' { if s[0] == '(' {
value, err := strconv.ParseFloat(s[1:], 64) value, err := strconv.ParseFloat(s[1:], 64)
if err != nil { if err != nil {
return nil, errors.New("ERR min or max is not a float") return nil, errors.New("ERR min or max is not a float")
} }
return &ScoreBorder{ return &ScoreBorder{
Inf: 0, Inf: 0,
Value: value, Value: value,
Exclude: true, Exclude: true,
}, nil }, nil
} else { } else {
value, err := strconv.ParseFloat(s, 64) value, err := strconv.ParseFloat(s, 64)
if err != nil { if err != nil {
return nil, errors.New("ERR min or max is not a float") return nil, errors.New("ERR min or max is not a float")
} }
return &ScoreBorder{ return &ScoreBorder{
Inf: 0, Inf: 0,
Value: value, Value: value,
Exclude: false, Exclude: false,
}, nil }, nil
} }
} }

View File

@@ -3,130 +3,129 @@ package sortedset
import "math/rand" import "math/rand"
const ( const (
maxLevel = 16 maxLevel = 16
) )
type Element struct { type Element struct {
Member string Member string
Score float64 Score float64
} }
// level aspect of a Node // level aspect of a Node
type Level struct { type Level struct {
forward *Node // forward node has greater score forward *Node // forward node has greater score
span int64 span int64
} }
type Node struct { type Node struct {
Element Element
backward *Node backward *Node
level []*Level // level[0] is base level level []*Level // level[0] is base level
} }
type skiplist struct { type skiplist struct {
header *Node header *Node
tail *Node tail *Node
length int64 length int64
level int16 level int16
} }
func makeNode(level int16, score float64, member string)*Node { func makeNode(level int16, score float64, member string) *Node {
n := &Node{ n := &Node{
Element: Element{ Element: Element{
Score: score, Score: score,
Member: member, Member: member,
}, },
level: make([]*Level, level), level: make([]*Level, level),
} }
for i := range n.level { for i := range n.level {
n.level[i] = new(Level) n.level[i] = new(Level)
} }
return n return n
} }
func makeSkiplist()*skiplist { func makeSkiplist() *skiplist {
return &skiplist{ return &skiplist{
level: 1, level: 1,
header: makeNode(maxLevel, 0, ""), header: makeNode(maxLevel, 0, ""),
} }
} }
func randomLevel() int16 { func randomLevel() int16 {
level := int16(1) level := int16(1)
for float32(rand.Int31()&0xFFFF) < (0.25 * 0xFFFF) { for float32(rand.Int31()&0xFFFF) < (0.25 * 0xFFFF) {
level++ level++
} }
if level < maxLevel { if level < maxLevel {
return level return level
} }
return maxLevel return maxLevel
} }
func (skiplist *skiplist)insert(member string, score float64)*Node { func (skiplist *skiplist) insert(member string, score float64) *Node {
update := make([]*Node, maxLevel) // link new node with node in `update` update := make([]*Node, maxLevel) // link new node with node in `update`
rank := make([]int64, maxLevel) rank := make([]int64, maxLevel)
// find position to insert // find position to insert
node := skiplist.header node := skiplist.header
for i := skiplist.level - 1; i >= 0; i-- { for i := skiplist.level - 1; i >= 0; i-- {
if i == skiplist.level - 1 { if i == skiplist.level-1 {
rank[i] = 0 rank[i] = 0
} else { } else {
rank[i] = rank[i + 1] // store rank that is crossed to reach the insert position rank[i] = rank[i+1] // store rank that is crossed to reach the insert position
} }
if node.level[i] != nil { if node.level[i] != nil {
// traverse the skip list // traverse the skip list
for node.level[i].forward != nil && for node.level[i].forward != nil &&
(node.level[i].forward.Score < score || (node.level[i].forward.Score < score ||
(node.level[i].forward.Score == score && node.level[i].forward.Member < member)) { // same score, different key (node.level[i].forward.Score == score && node.level[i].forward.Member < member)) { // same score, different key
rank[i] += node.level[i].span rank[i] += node.level[i].span
node = node.level[i].forward node = node.level[i].forward
} }
} }
update[i] = node update[i] = node
} }
level := randomLevel() level := randomLevel()
// extend skiplist level // extend skiplist level
if level > skiplist.level { if level > skiplist.level {
for i := skiplist.level; i < level; i++ { for i := skiplist.level; i < level; i++ {
rank[i] = 0 rank[i] = 0
update[i] = skiplist.header update[i] = skiplist.header
update[i].level[i].span = skiplist.length update[i].level[i].span = skiplist.length
} }
skiplist.level = level skiplist.level = level
} }
// make node and link into skiplist // make node and link into skiplist
node = makeNode(level, score, member) node = makeNode(level, score, member)
for i := int16(0); i < level; i++ { for i := int16(0); i < level; i++ {
node.level[i].forward = update[i].level[i].forward node.level[i].forward = update[i].level[i].forward
update[i].level[i].forward = node update[i].level[i].forward = node
// update span covered by update[i] as node is inserted here // update span covered by update[i] as node is inserted here
node.level[i].span = update[i].level[i].span - (rank[0] - rank[i]) node.level[i].span = update[i].level[i].span - (rank[0] - rank[i])
update[i].level[i].span = (rank[0] - rank[i]) + 1 update[i].level[i].span = (rank[0] - rank[i]) + 1
} }
// increment span for untouched levels // increment span for untouched levels
for i := level; i < skiplist.level; i++ { for i := level; i < skiplist.level; i++ {
update[i].level[i].span++ update[i].level[i].span++
} }
// set backward node // set backward node
if update[0] == skiplist.header { if update[0] == skiplist.header {
node.backward = nil node.backward = nil
} else { } else {
node.backward = update[0] node.backward = update[0]
} }
if node.level[0].forward != nil { if node.level[0].forward != nil {
node.level[0].forward.backward = node node.level[0].forward.backward = node
} else { } else {
skiplist.tail = node skiplist.tail = node
} }
skiplist.length++ skiplist.length++
return node return node
} }
/* /*
@@ -134,212 +133,212 @@ func (skiplist *skiplist)insert(member string, score float64)*Node {
* param update: backward node (of target) * param update: backward node (of target)
*/ */
func (skiplist *skiplist) removeNode(node *Node, update []*Node) { func (skiplist *skiplist) removeNode(node *Node, update []*Node) {
for i := int16(0); i < skiplist.level; i++ { for i := int16(0); i < skiplist.level; i++ {
if update[i].level[i].forward == node { if update[i].level[i].forward == node {
update[i].level[i].span += node.level[i].span - 1 update[i].level[i].span += node.level[i].span - 1
update[i].level[i].forward = node.level[i].forward update[i].level[i].forward = node.level[i].forward
} else { } else {
update[i].level[i].span-- update[i].level[i].span--
} }
} }
if node.level[0].forward != nil { if node.level[0].forward != nil {
node.level[0].forward.backward = node.backward node.level[0].forward.backward = node.backward
} else { } else {
skiplist.tail = node.backward skiplist.tail = node.backward
} }
for skiplist.level > 1 && skiplist.header.level[skiplist.level-1].forward == nil { for skiplist.level > 1 && skiplist.header.level[skiplist.level-1].forward == nil {
skiplist.level-- skiplist.level--
} }
skiplist.length-- skiplist.length--
} }
/* /*
* return: has found and removed node * return: has found and removed node
*/ */
func (skiplist *skiplist) remove(member string, score float64)bool { func (skiplist *skiplist) remove(member string, score float64) bool {
/* /*
* find backward node (of target) or last node of each level * find backward node (of target) or last node of each level
* their forward need to be updated * their forward need to be updated
*/ */
update := make([]*Node, maxLevel) update := make([]*Node, maxLevel)
node := skiplist.header node := skiplist.header
for i := skiplist.level - 1; i >= 0; i-- { for i := skiplist.level - 1; i >= 0; i-- {
for node.level[i].forward != nil && for node.level[i].forward != nil &&
(node.level[i].forward.Score < score || (node.level[i].forward.Score < score ||
(node.level[i].forward.Score == score && (node.level[i].forward.Score == score &&
node.level[i].forward.Member < member)) { node.level[i].forward.Member < member)) {
node = node.level[i].forward node = node.level[i].forward
} }
update[i] = node update[i] = node
} }
node = node.level[0].forward node = node.level[0].forward
if node != nil && score == node.Score && node.Member == member { if node != nil && score == node.Score && node.Member == member {
skiplist.removeNode(node, update) skiplist.removeNode(node, update)
// free x // free x
return true return true
} }
return false return false
} }
/* /*
* return: 1 based rank, 0 means member not found * return: 1 based rank, 0 means member not found
*/ */
func (skiplist *skiplist) getRank(member string, score float64)int64 { func (skiplist *skiplist) getRank(member string, score float64) int64 {
var rank int64 = 0 var rank int64 = 0
x := skiplist.header x := skiplist.header
for i := skiplist.level - 1; i >= 0; i-- { for i := skiplist.level - 1; i >= 0; i-- {
for x.level[i].forward != nil && for x.level[i].forward != nil &&
(x.level[i].forward.Score < score || (x.level[i].forward.Score < score ||
(x.level[i].forward.Score == score && (x.level[i].forward.Score == score &&
x.level[i].forward.Member <= member)) { x.level[i].forward.Member <= member)) {
rank += x.level[i].span rank += x.level[i].span
x = x.level[i].forward x = x.level[i].forward
} }
/* x might be equal to zsl->header, so test if obj is non-NULL */ /* x might be equal to zsl->header, so test if obj is non-NULL */
if x.Member == member { if x.Member == member {
return rank return rank
} }
} }
return 0 return 0
} }
/* /*
* 1-based rank * 1-based rank
*/ */
func (skiplist *skiplist) getByRank(rank int64)*Node { func (skiplist *skiplist) getByRank(rank int64) *Node {
var i int64 = 0 var i int64 = 0
n := skiplist.header n := skiplist.header
// scan from top level // scan from top level
for level := skiplist.level - 1; level >= 0; level-- { for level := skiplist.level - 1; level >= 0; level-- {
for n.level[level].forward != nil && (i+n.level[level].span) <= rank { for n.level[level].forward != nil && (i+n.level[level].span) <= rank {
i += n.level[level].span i += n.level[level].span
n = n.level[level].forward n = n.level[level].forward
} }
if i == rank { if i == rank {
return n return n
} }
} }
return nil return nil
} }
func (skiplist *skiplist) hasInRange(min *ScoreBorder, max *ScoreBorder) bool { func (skiplist *skiplist) hasInRange(min *ScoreBorder, max *ScoreBorder) bool {
// min & max = empty // min & max = empty
if min.Value > max.Value || (min.Value == max.Value && (min.Exclude || max.Exclude)) { if min.Value > max.Value || (min.Value == max.Value && (min.Exclude || max.Exclude)) {
return false return false
} }
// min > tail // min > tail
n := skiplist.tail n := skiplist.tail
if n == nil || !min.less(n.Score) { if n == nil || !min.less(n.Score) {
return false return false
} }
// max < head // max < head
n = skiplist.header.level[0].forward n = skiplist.header.level[0].forward
if n == nil || !max.greater(n.Score) { if n == nil || !max.greater(n.Score) {
return false return false
} }
return true return true
} }
func (skiplist *skiplist) getFirstInScoreRange(min *ScoreBorder, max *ScoreBorder) *Node { func (skiplist *skiplist) getFirstInScoreRange(min *ScoreBorder, max *ScoreBorder) *Node {
if !skiplist.hasInRange(min, max) { if !skiplist.hasInRange(min, max) {
return nil return nil
} }
n := skiplist.header n := skiplist.header
// scan from top level // scan from top level
for level := skiplist.level - 1; level >= 0; level-- { for level := skiplist.level - 1; level >= 0; level-- {
// if forward is not in range than move forward // if forward is not in range than move forward
for n.level[level].forward != nil && !min.less(n.level[level].forward.Score) { for n.level[level].forward != nil && !min.less(n.level[level].forward.Score) {
n = n.level[level].forward n = n.level[level].forward
} }
} }
/* This is an inner range, so the next node cannot be NULL. */ /* This is an inner range, so the next node cannot be NULL. */
n = n.level[0].forward n = n.level[0].forward
if !max.greater(n.Score) { if !max.greater(n.Score) {
return nil return nil
} }
return n return n
} }
func (skiplist *skiplist) getLastInScoreRange(min *ScoreBorder, max *ScoreBorder) *Node { func (skiplist *skiplist) getLastInScoreRange(min *ScoreBorder, max *ScoreBorder) *Node {
if !skiplist.hasInRange(min, max) { if !skiplist.hasInRange(min, max) {
return nil return nil
} }
n := skiplist.header n := skiplist.header
// scan from top level // scan from top level
for level := skiplist.level - 1; level >= 0; level-- { for level := skiplist.level - 1; level >= 0; level-- {
for n.level[level].forward != nil && max.greater(n.level[level].forward.Score) { for n.level[level].forward != nil && max.greater(n.level[level].forward.Score) {
n = n.level[level].forward n = n.level[level].forward
} }
} }
if !min.less(n.Score) { if !min.less(n.Score) {
return nil return nil
} }
return n return n
} }
/* /*
* return removed elements * return removed elements
*/ */
func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder)(removed []*Element) { func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder) (removed []*Element) {
update := make([]*Node, maxLevel) update := make([]*Node, maxLevel)
removed = make([]*Element, 0) removed = make([]*Element, 0)
// find backward nodes (of target range) or last node of each level // find backward nodes (of target range) or last node of each level
node := skiplist.header node := skiplist.header
for i := skiplist.level - 1; i >= 0; i-- { for i := skiplist.level - 1; i >= 0; i-- {
for node.level[i].forward != nil { for node.level[i].forward != nil {
if min.less(node.level[i].forward.Score) { // already in range if min.less(node.level[i].forward.Score) { // already in range
break break
} }
node = node.level[i].forward node = node.level[i].forward
} }
update[i] = node update[i] = node
} }
// node is the first one within range // node is the first one within range
node = node.level[0].forward node = node.level[0].forward
// remove nodes in range // remove nodes in range
for node != nil { for node != nil {
if !max.greater(node.Score) { // already out of range if !max.greater(node.Score) { // already out of range
break break
} }
next := node.level[0].forward next := node.level[0].forward
removedElement := node.Element removedElement := node.Element
removed = append(removed, &removedElement) removed = append(removed, &removedElement)
skiplist.removeNode(node, update) skiplist.removeNode(node, update)
node = next node = next
} }
return removed return removed
} }
// 1-based rank, including start, exclude stop // 1-based rank, including start, exclude stop
func (skiplist *skiplist) RemoveRangeByRank(start int64, stop int64)(removed []*Element) { func (skiplist *skiplist) RemoveRangeByRank(start int64, stop int64) (removed []*Element) {
var i int64 = 0 // rank of iterator var i int64 = 0 // rank of iterator
update := make([]*Node, maxLevel) update := make([]*Node, maxLevel)
removed = make([]*Element, 0) removed = make([]*Element, 0)
// scan from top level // scan from top level
node := skiplist.header node := skiplist.header
for level := skiplist.level - 1; level >= 0; level-- { for level := skiplist.level - 1; level >= 0; level-- {
for node.level[level].forward != nil && (i+node.level[level].span) < start { for node.level[level].forward != nil && (i+node.level[level].span) < start {
i += node.level[level].span i += node.level[level].span
node = node.level[level].forward node = node.level[level].forward
} }
update[level] = node update[level] = node
} }
i++ i++
node = node.level[0].forward // first node in range node = node.level[0].forward // first node in range
// remove nodes in range // remove nodes in range
for node != nil && i < stop { for node != nil && i < stop {
next := node.level[0].forward next := node.level[0].forward
removedElement := node.Element removedElement := node.Element
removed = append(removed, &removedElement) removed = append(removed, &removedElement)
skiplist.removeNode(node, update) skiplist.removeNode(node, update)
node = next node = next
i++ i++
} }
return removed return removed
} }

View File

@@ -1,227 +1,226 @@
package sortedset package sortedset
import ( import (
"strconv" "strconv"
) )
type SortedSet struct { type SortedSet struct {
dict map[string]*Element dict map[string]*Element
skiplist *skiplist skiplist *skiplist
} }
func Make()*SortedSet { func Make() *SortedSet {
return &SortedSet{ return &SortedSet{
dict: make(map[string]*Element), dict: make(map[string]*Element),
skiplist: makeSkiplist(), skiplist: makeSkiplist(),
} }
} }
/* /*
* return: has inserted new node * return: has inserted new node
*/ */
func (sortedSet *SortedSet)Add(member string, score float64)bool { func (sortedSet *SortedSet) Add(member string, score float64) bool {
element, ok := sortedSet.dict[member] element, ok := sortedSet.dict[member]
sortedSet.dict[member] = &Element{ sortedSet.dict[member] = &Element{
Member: member, Member: member,
Score: score, Score: score,
} }
if ok { if ok {
if score != element.Score { if score != element.Score {
sortedSet.skiplist.remove(member, score) sortedSet.skiplist.remove(member, score)
sortedSet.skiplist.insert(member, score) sortedSet.skiplist.insert(member, score)
} }
return false return false
} else { } else {
sortedSet.skiplist.insert(member, score) sortedSet.skiplist.insert(member, score)
return true return true
} }
} }
func (sortedSet *SortedSet) Len()int64 { func (sortedSet *SortedSet) Len() int64 {
return int64(len(sortedSet.dict)) return int64(len(sortedSet.dict))
} }
func (sortedSet *SortedSet) Get(member string) (element *Element, ok bool) { func (sortedSet *SortedSet) Get(member string) (element *Element, ok bool) {
element, ok = sortedSet.dict[member] element, ok = sortedSet.dict[member]
if !ok { if !ok {
return nil, false return nil, false
} }
return element, true return element, true
} }
func (sortedSet *SortedSet) Remove(member string)bool { func (sortedSet *SortedSet) Remove(member string) bool {
v, ok := sortedSet.dict[member] v, ok := sortedSet.dict[member]
if ok { if ok {
sortedSet.skiplist.remove(member, v.Score) sortedSet.skiplist.remove(member, v.Score)
delete(sortedSet.dict, member) delete(sortedSet.dict, member)
return true return true
} }
return false return false
} }
/** /**
* get 0-based rank * get 0-based rank
*/ */
func (sortedSet *SortedSet) GetRank(member string, desc bool) (rank int64) { func (sortedSet *SortedSet) GetRank(member string, desc bool) (rank int64) {
element, ok := sortedSet.dict[member] element, ok := sortedSet.dict[member]
if !ok { if !ok {
return -1 return -1
} }
r := sortedSet.skiplist.getRank(member, element.Score) r := sortedSet.skiplist.getRank(member, element.Score)
if desc { if desc {
r = sortedSet.skiplist.length - r r = sortedSet.skiplist.length - r
} else { } else {
r-- r--
} }
return r return r
} }
/** /**
* traverse [start, stop), 0-based rank * traverse [start, stop), 0-based rank
*/ */
func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer func(element *Element)bool) { func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer func(element *Element) bool) {
size := int64(sortedSet.Len()) size := int64(sortedSet.Len())
if start < 0 || start >= size { if start < 0 || start >= size {
panic("illegal start " + strconv.FormatInt(start, 10)) panic("illegal start " + strconv.FormatInt(start, 10))
} }
if stop < start || stop > size { if stop < start || stop > size {
panic("illegal end " + strconv.FormatInt(stop, 10)) panic("illegal end " + strconv.FormatInt(stop, 10))
} }
// find start node // find start node
var node *Node var node *Node
if desc { if desc {
node = sortedSet.skiplist.tail node = sortedSet.skiplist.tail
if start > 0 { if start > 0 {
node = sortedSet.skiplist.getByRank(int64(size - start)) node = sortedSet.skiplist.getByRank(int64(size - start))
} }
} else { } else {
node = sortedSet.skiplist.header.level[0].forward node = sortedSet.skiplist.header.level[0].forward
if start > 0 { if start > 0 {
node = sortedSet.skiplist.getByRank(int64(start + 1)) node = sortedSet.skiplist.getByRank(int64(start + 1))
} }
} }
sliceSize := int(stop - start) sliceSize := int(stop - start)
for i := 0; i < sliceSize; i++ { for i := 0; i < sliceSize; i++ {
if !consumer(&node.Element) { if !consumer(&node.Element) {
break break
} }
if desc { if desc {
node = node.backward node = node.backward
} else { } else {
node = node.level[0].forward node = node.level[0].forward
} }
} }
} }
/** /**
* return [start, stop), 0-based rank * return [start, stop), 0-based rank
* assert start in [0, size), stop in [start, size] * assert start in [0, size), stop in [start, size]
*/ */
func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool)[]*Element { func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool) []*Element {
sliceSize := int(stop - start) sliceSize := int(stop - start)
slice := make([]*Element, sliceSize) slice := make([]*Element, sliceSize)
i := 0 i := 0
sortedSet.ForEach(start, stop, desc, func(element *Element)bool { sortedSet.ForEach(start, stop, desc, func(element *Element) bool {
slice[i] = element slice[i] = element
i++ i++
return true return true
}) })
return slice return slice
} }
func (sortedSet *SortedSet) Count(min *ScoreBorder, max *ScoreBorder)int64 { func (sortedSet *SortedSet) Count(min *ScoreBorder, max *ScoreBorder) int64 {
var i int64 = 0 var i int64 = 0
// ascending order // ascending order
sortedSet.ForEach(0, sortedSet.Len(), false, func(element *Element) bool { sortedSet.ForEach(0, sortedSet.Len(), false, func(element *Element) bool {
gtMin := min.less(element.Score) // greater than min gtMin := min.less(element.Score) // greater than min
if !gtMin { if !gtMin {
// has not into range, continue foreach // has not into range, continue foreach
return true return true
} }
ltMax := max.greater(element.Score) // less than max ltMax := max.greater(element.Score) // less than max
if !ltMax { if !ltMax {
// break through score border, break foreach // break through score border, break foreach
return false return false
} }
// gtMin && ltMax // gtMin && ltMax
i++ i++
return true return true
}) })
return i return i
} }
func (sortedSet *SortedSet) ForEachByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool, consumer func(element *Element) bool) { func (sortedSet *SortedSet) ForEachByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool, consumer func(element *Element) bool) {
// find start node // find start node
var node *Node var node *Node
if desc { if desc {
node = sortedSet.skiplist.getLastInScoreRange(min, max) node = sortedSet.skiplist.getLastInScoreRange(min, max)
} else { } else {
node = sortedSet.skiplist.getFirstInScoreRange(min, max) node = sortedSet.skiplist.getFirstInScoreRange(min, max)
} }
for node != nil && offset > 0 { for node != nil && offset > 0 {
if desc { if desc {
node = node.backward node = node.backward
} else { } else {
node = node.level[0].forward node = node.level[0].forward
} }
offset-- offset--
} }
// A negative limit returns all elements from the offset // A negative limit returns all elements from the offset
for i := 0; (i < int(limit) || limit < 0) && node != nil; i++ { for i := 0; (i < int(limit) || limit < 0) && node != nil; i++ {
if !consumer(&node.Element) { if !consumer(&node.Element) {
break break
} }
if desc { if desc {
node = node.backward node = node.backward
} else { } else {
node = node.level[0].forward node = node.level[0].forward
} }
if node == nil { if node == nil {
break break
} }
gtMin := min.less(node.Element.Score) // greater than min gtMin := min.less(node.Element.Score) // greater than min
ltMax := max.greater(node.Element.Score) ltMax := max.greater(node.Element.Score)
if !gtMin || !ltMax { if !gtMin || !ltMax {
break // break through score border break // break through score border
} }
} }
} }
/* /*
* param limit: <0 means no limit * param limit: <0 means no limit
*/ */
func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool)[]*Element { func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool) []*Element {
if limit == 0 || offset < 0{ if limit == 0 || offset < 0 {
return make([]*Element, 0) return make([]*Element, 0)
} }
slice := make([]*Element, 0) slice := make([]*Element, 0)
sortedSet.ForEachByScore(min, max, offset, limit, desc, func(element *Element) bool { sortedSet.ForEachByScore(min, max, offset, limit, desc, func(element *Element) bool {
slice = append(slice, element) slice = append(slice, element)
return true return true
}) })
return slice return slice
} }
func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder)int64 { func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder) int64 {
removed := sortedSet.skiplist.RemoveRangeByScore(min, max) removed := sortedSet.skiplist.RemoveRangeByScore(min, max)
for _, element := range removed { for _, element := range removed {
delete(sortedSet.dict, element.Member) delete(sortedSet.dict, element.Member)
} }
return int64(len(removed)) return int64(len(removed))
} }
/* /*
* 0-based rank, [start, stop) * 0-based rank, [start, stop)
*/ */
func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64)int64 { func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64) int64 {
removed := sortedSet.skiplist.RemoveRangeByRank(start + 1, stop + 1) removed := sortedSet.skiplist.RemoveRangeByRank(start+1, stop+1)
for _, element := range removed { for _, element := range removed {
delete(sortedSet.dict, element.Member) delete(sortedSet.dict, element.Member)
} }
return int64(len(removed)) return int64(len(removed))
} }

View File

@@ -1,28 +1,28 @@
package utils package utils
func Equals(a interface{}, b interface{})bool { func Equals(a interface{}, b interface{}) bool {
sliceA, okA := a.([]byte) sliceA, okA := a.([]byte)
sliceB, okB := b.([]byte) sliceB, okB := b.([]byte)
if okA && okB { if okA && okB {
return BytesEquals(sliceA, sliceB) return BytesEquals(sliceA, sliceB)
} }
return a == b return a == b
} }
func BytesEquals(a []byte, b []byte) bool { func BytesEquals(a []byte, b []byte) bool {
if (a == nil && b != nil) || (a != nil && b == nil) { if (a == nil && b != nil) || (a != nil && b == nil) {
return false return false
} }
if len(a) != len(b) { if len(a) != len(b) {
return false return false
} }
size := len(a) size := len(a)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
av := a[i] av := a[i]
bv := b[i] bv := b[i]
if av != bv { if av != bv {
return false return false
} }
} }
return true return true
} }

View File

@@ -2,7 +2,7 @@ package db
import ( import (
"bufio" "bufio"
"github.com/HDT3213/godis/src/config" "github.com/HDT3213/godis/src/config"
"github.com/HDT3213/godis/src/datastruct/dict" "github.com/HDT3213/godis/src/datastruct/dict"
List "github.com/HDT3213/godis/src/datastruct/list" List "github.com/HDT3213/godis/src/datastruct/list"
"github.com/HDT3213/godis/src/datastruct/lock" "github.com/HDT3213/godis/src/datastruct/lock"
@@ -29,18 +29,18 @@ func makeExpireCmd(key string, expireAt time.Time) *reply.MultiBulkReply {
} }
func makeAofCmd(cmd string, args [][]byte) *reply.MultiBulkReply { func makeAofCmd(cmd string, args [][]byte) *reply.MultiBulkReply {
params := make([][]byte, len(args)+1) params := make([][]byte, len(args)+1)
copy(params[1:], args) copy(params[1:], args)
params[0] = []byte(cmd) params[0] = []byte(cmd)
return reply.MakeMultiBulkReply(params) return reply.MakeMultiBulkReply(params)
} }
// send command to aof // send command to aof
func (db *DB) AddAof(args *reply.MultiBulkReply) { func (db *DB) AddAof(args *reply.MultiBulkReply) {
// aofChan == nil when loadAof // aofChan == nil when loadAof
if config.Properties.AppendOnly && db.aofChan != nil { if config.Properties.AppendOnly && db.aofChan != nil {
db.aofChan <- args db.aofChan <- args
} }
} }
// listen aof channel and write into file // listen aof channel and write into file
@@ -72,12 +72,12 @@ func trim(msg []byte) string {
// read aof file // read aof file
func (db *DB) loadAof(maxBytes int) { func (db *DB) loadAof(maxBytes int) {
// delete aofChan to prevent write again // delete aofChan to prevent write again
aofChan := db.aofChan aofChan := db.aofChan
db.aofChan = nil db.aofChan = nil
defer func(aofChan chan *reply.MultiBulkReply) { defer func(aofChan chan *reply.MultiBulkReply) {
db.aofChan = aofChan db.aofChan = aofChan
}(aofChan) }(aofChan)
file, err := os.Open(db.aofFilename) file, err := os.Open(db.aofFilename)
if err != nil { if err != nil {

View File

@@ -1,297 +1,297 @@
package db package db
import ( import (
"fmt" "fmt"
"github.com/HDT3213/godis/src/config" "github.com/HDT3213/godis/src/config"
"github.com/HDT3213/godis/src/datastruct/dict" "github.com/HDT3213/godis/src/datastruct/dict"
List "github.com/HDT3213/godis/src/datastruct/list" List "github.com/HDT3213/godis/src/datastruct/list"
"github.com/HDT3213/godis/src/datastruct/lock" "github.com/HDT3213/godis/src/datastruct/lock"
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/lib/logger" "github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/pubsub" "github.com/HDT3213/godis/src/pubsub"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"os" "os"
"runtime/debug" "runtime/debug"
"strings" "strings"
"sync" "sync"
"time" "time"
) )
type DataEntity struct { type DataEntity struct {
Data interface{} Data interface{}
} }
const ( const (
dataDictSize = 1 << 16 dataDictSize = 1 << 16
ttlDictSize = 1 << 10 ttlDictSize = 1 << 10
lockerSize = 128 lockerSize = 128
aofQueueSize = 1 << 16 aofQueueSize = 1 << 16
) )
// args don't include cmd line // args don't include cmd line
type CmdFunc func(db *DB, args [][]byte) redis.Reply type CmdFunc func(db *DB, args [][]byte) redis.Reply
type DB struct { type DB struct {
// key -> DataEntity // key -> DataEntity
Data dict.Dict Data dict.Dict
// key -> expireTime (time.Time) // key -> expireTime (time.Time)
TTLMap dict.Dict TTLMap dict.Dict
// channel -> list<*client> // channel -> list<*client>
SubMap dict.Dict SubMap dict.Dict
// dict will ensure thread safety of its method // dict will ensure thread safety of its method
// use this mutex for complicated command only, eg. rpush, incr ... // use this mutex for complicated command only, eg. rpush, incr ...
Locker *lock.Locks Locker *lock.Locks
// TimerTask interval // TimerTask interval
interval time.Duration interval time.Duration
stopWorld sync.WaitGroup stopWorld sync.WaitGroup
hub *pubsub.Hub hub *pubsub.Hub
// main goroutine send commands to aof goroutine through aofChan // main goroutine send commands to aof goroutine through aofChan
aofChan chan *reply.MultiBulkReply aofChan chan *reply.MultiBulkReply
aofFile *os.File aofFile *os.File
aofFilename string aofFilename string
aofRewriteChan chan *reply.MultiBulkReply aofRewriteChan chan *reply.MultiBulkReply
pausingAof sync.RWMutex pausingAof sync.RWMutex
} }
var router = MakeRouter() var router = MakeRouter()
func MakeDB() *DB { func MakeDB() *DB {
db := &DB{ db := &DB{
Data: dict.MakeConcurrent(dataDictSize), Data: dict.MakeConcurrent(dataDictSize),
TTLMap: dict.MakeConcurrent(ttlDictSize), TTLMap: dict.MakeConcurrent(ttlDictSize),
Locker: lock.Make(lockerSize), Locker: lock.Make(lockerSize),
interval: 5 * time.Second, interval: 5 * time.Second,
hub: pubsub.MakeHub(), hub: pubsub.MakeHub(),
} }
// aof // aof
if config.Properties.AppendOnly { if config.Properties.AppendOnly {
db.aofFilename = config.Properties.AppendFilename db.aofFilename = config.Properties.AppendFilename
db.loadAof(0) db.loadAof(0)
aofFile, err := os.OpenFile(db.aofFilename, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600) aofFile, err := os.OpenFile(db.aofFilename, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
if err != nil { if err != nil {
logger.Warn(err) logger.Warn(err)
} else { } else {
db.aofFile = aofFile db.aofFile = aofFile
db.aofChan = make(chan *reply.MultiBulkReply, aofQueueSize) db.aofChan = make(chan *reply.MultiBulkReply, aofQueueSize)
} }
go func() { go func() {
db.handleAof() db.handleAof()
}() }()
} }
// start timer // start timer
db.TimerTask() db.TimerTask()
return db return db
} }
func (db *DB) Close() { func (db *DB) Close() {
if db.aofFile != nil { if db.aofFile != nil {
err := db.aofFile.Close() err := db.aofFile.Close()
if err != nil { if err != nil {
logger.Warn(err) logger.Warn(err)
} }
} }
} }
func (db *DB) Exec(c redis.Connection, args [][]byte) (result redis.Reply) { func (db *DB) Exec(c redis.Connection, args [][]byte) (result redis.Reply) {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack()))) logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack())))
result = &reply.UnknownErrReply{} result = &reply.UnknownErrReply{}
} }
}() }()
cmd := strings.ToLower(string(args[0])) cmd := strings.ToLower(string(args[0]))
// special commands // special commands
if cmd == "subscribe" { if cmd == "subscribe" {
if len(args) < 2 { if len(args) < 2 {
return &reply.ArgNumErrReply{Cmd: "subscribe"} return &reply.ArgNumErrReply{Cmd: "subscribe"}
} }
return pubsub.Subscribe(db.hub, c, args[1:]) return pubsub.Subscribe(db.hub, c, args[1:])
} else if cmd == "publish" { } else if cmd == "publish" {
return pubsub.Publish(db.hub, args[1:]) return pubsub.Publish(db.hub, args[1:])
} else if cmd == "unsubscribe" { } else if cmd == "unsubscribe" {
return pubsub.UnSubscribe(db.hub, c, args[1:]) return pubsub.UnSubscribe(db.hub, c, args[1:])
} else if cmd == "bgrewriteaof" { } else if cmd == "bgrewriteaof" {
// aof.go imports router.go, router.go cannot import BGRewriteAOF from aof.go // aof.go imports router.go, router.go cannot import BGRewriteAOF from aof.go
reply := BGRewriteAOF(db, args[1:]) reply := BGRewriteAOF(db, args[1:])
return reply return reply
} }
// normal commands // normal commands
cmdFunc, ok := router[cmd] cmdFunc, ok := router[cmd]
if !ok { if !ok {
return reply.MakeErrReply("ERR unknown command '" + cmd + "'") return reply.MakeErrReply("ERR unknown command '" + cmd + "'")
} }
if len(args) > 1 { if len(args) > 1 {
result = cmdFunc(db, args[1:]) result = cmdFunc(db, args[1:])
} else { } else {
result = cmdFunc(db, [][]byte{}) result = cmdFunc(db, [][]byte{})
} }
// aof // aof
return return
} }
/* ---- Data Access ----- */ /* ---- Data Access ----- */
func (db *DB) Get(key string) (*DataEntity, bool) { func (db *DB) Get(key string) (*DataEntity, bool) {
db.stopWorld.Wait() db.stopWorld.Wait()
raw, ok := db.Data.Get(key) raw, ok := db.Data.Get(key)
if !ok { if !ok {
return nil, false return nil, false
} }
if db.IsExpired(key) { if db.IsExpired(key) {
return nil, false return nil, false
} }
entity, _ := raw.(*DataEntity) entity, _ := raw.(*DataEntity)
return entity, true return entity, true
} }
func (db *DB) Put(key string, entity *DataEntity) int { func (db *DB) Put(key string, entity *DataEntity) int {
db.stopWorld.Wait() db.stopWorld.Wait()
return db.Data.Put(key, entity) return db.Data.Put(key, entity)
} }
func (db *DB) PutIfExists(key string, entity *DataEntity) int { func (db *DB) PutIfExists(key string, entity *DataEntity) int {
db.stopWorld.Wait() db.stopWorld.Wait()
return db.Data.PutIfExists(key, entity) return db.Data.PutIfExists(key, entity)
} }
func (db *DB) PutIfAbsent(key string, entity *DataEntity) int { func (db *DB) PutIfAbsent(key string, entity *DataEntity) int {
db.stopWorld.Wait() db.stopWorld.Wait()
return db.Data.PutIfAbsent(key, entity) return db.Data.PutIfAbsent(key, entity)
} }
func (db *DB) Remove(key string) { func (db *DB) Remove(key string) {
db.stopWorld.Wait() db.stopWorld.Wait()
db.Data.Remove(key) db.Data.Remove(key)
db.TTLMap.Remove(key) db.TTLMap.Remove(key)
} }
func (db *DB) Removes(keys ...string) (deleted int) { func (db *DB) Removes(keys ...string) (deleted int) {
db.stopWorld.Wait() db.stopWorld.Wait()
deleted = 0 deleted = 0
for _, key := range keys { for _, key := range keys {
_, exists := db.Data.Get(key) _, exists := db.Data.Get(key)
if exists { if exists {
db.Data.Remove(key) db.Data.Remove(key)
db.TTLMap.Remove(key) db.TTLMap.Remove(key)
deleted++ deleted++
} }
} }
return deleted return deleted
} }
func (db *DB) Flush() { func (db *DB) Flush() {
db.stopWorld.Add(1) db.stopWorld.Add(1)
defer db.stopWorld.Done() defer db.stopWorld.Done()
db.Data = dict.MakeConcurrent(dataDictSize) db.Data = dict.MakeConcurrent(dataDictSize)
db.TTLMap = dict.MakeConcurrent(ttlDictSize) db.TTLMap = dict.MakeConcurrent(ttlDictSize)
db.Locker = lock.Make(lockerSize) db.Locker = lock.Make(lockerSize)
} }
/* ---- Lock Function ----- */ /* ---- Lock Function ----- */
func (db *DB) Lock(key string) { func (db *DB) Lock(key string) {
db.Locker.Lock(key) db.Locker.Lock(key)
} }
func (db *DB) RLock(key string) { func (db *DB) RLock(key string) {
db.Locker.RLock(key) db.Locker.RLock(key)
} }
func (db *DB) UnLock(key string) { func (db *DB) UnLock(key string) {
db.Locker.UnLock(key) db.Locker.UnLock(key)
} }
func (db *DB) RUnLock(key string) { func (db *DB) RUnLock(key string) {
db.Locker.RUnLock(key) db.Locker.RUnLock(key)
} }
func (db *DB) Locks(keys ...string) { func (db *DB) Locks(keys ...string) {
db.Locker.Locks(keys...) db.Locker.Locks(keys...)
} }
func (db *DB) RLocks(keys ...string) { func (db *DB) RLocks(keys ...string) {
db.Locker.RLocks(keys...) db.Locker.RLocks(keys...)
} }
func (db *DB) UnLocks(keys ...string) { func (db *DB) UnLocks(keys ...string) {
db.Locker.UnLocks(keys...) db.Locker.UnLocks(keys...)
} }
func (db *DB) RUnLocks(keys ...string) { func (db *DB) RUnLocks(keys ...string) {
db.Locker.RUnLocks(keys...) db.Locker.RUnLocks(keys...)
} }
/* ---- TTL Functions ---- */ /* ---- TTL Functions ---- */
func (db *DB) Expire(key string, expireTime time.Time) { func (db *DB) Expire(key string, expireTime time.Time) {
db.stopWorld.Wait() db.stopWorld.Wait()
db.TTLMap.Put(key, expireTime) db.TTLMap.Put(key, expireTime)
} }
func (db *DB) Persist(key string) { func (db *DB) Persist(key string) {
db.stopWorld.Wait() db.stopWorld.Wait()
db.TTLMap.Remove(key) db.TTLMap.Remove(key)
} }
func (db *DB) IsExpired(key string) bool { func (db *DB) IsExpired(key string) bool {
rawExpireTime, ok := db.TTLMap.Get(key) rawExpireTime, ok := db.TTLMap.Get(key)
if !ok { if !ok {
return false return false
} }
expireTime, _ := rawExpireTime.(time.Time) expireTime, _ := rawExpireTime.(time.Time)
expired := time.Now().After(expireTime) expired := time.Now().After(expireTime)
if expired { if expired {
db.Remove(key) db.Remove(key)
} }
return expired return expired
} }
func (db *DB) CleanExpired() { func (db *DB) CleanExpired() {
now := time.Now() now := time.Now()
toRemove := &List.LinkedList{} toRemove := &List.LinkedList{}
db.TTLMap.ForEach(func(key string, val interface{}) bool { db.TTLMap.ForEach(func(key string, val interface{}) bool {
expireTime, _ := val.(time.Time) expireTime, _ := val.(time.Time)
if now.After(expireTime) { if now.After(expireTime) {
// expired // expired
db.Data.Remove(key) db.Data.Remove(key)
toRemove.Add(key) toRemove.Add(key)
} }
return true return true
}) })
toRemove.ForEach(func(i int, val interface{}) bool { toRemove.ForEach(func(i int, val interface{}) bool {
key, _ := val.(string) key, _ := val.(string)
db.TTLMap.Remove(key) db.TTLMap.Remove(key)
return true return true
}) })
} }
func (db *DB) TimerTask() { func (db *DB) TimerTask() {
ticker := time.NewTicker(db.interval) ticker := time.NewTicker(db.interval)
go func() { go func() {
for range ticker.C { for range ticker.C {
db.CleanExpired() db.CleanExpired()
} }
}() }()
} }
/* ---- Subscribe Functions ---- */ /* ---- Subscribe Functions ---- */
func (db *DB) AfterClientClose(c redis.Connection) { func (db *DB) AfterClientClose(c redis.Connection) {
pubsub.UnsubscribeAll(db.hub, c) pubsub.UnsubscribeAll(db.hub, c)
} }

View File

@@ -1,422 +1,422 @@
package db package db
import ( import (
Dict "github.com/HDT3213/godis/src/datastruct/dict" Dict "github.com/HDT3213/godis/src/datastruct/dict"
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"github.com/shopspring/decimal" "github.com/shopspring/decimal"
"strconv" "strconv"
) )
func (db *DB) getAsDict(key string) (Dict.Dict, reply.ErrorReply) { func (db *DB) getAsDict(key string) (Dict.Dict, reply.ErrorReply) {
entity, exists := db.Get(key) entity, exists := db.Get(key)
if !exists { if !exists {
return nil, nil return nil, nil
} }
dict, ok := entity.Data.(Dict.Dict) dict, ok := entity.Data.(Dict.Dict)
if !ok { if !ok {
return nil, &reply.WrongTypeErrReply{} return nil, &reply.WrongTypeErrReply{}
} }
return dict, nil return dict, nil
} }
func (db *DB) getOrInitDict(key string) (dict Dict.Dict, inited bool, errReply reply.ErrorReply) { func (db *DB) getOrInitDict(key string) (dict Dict.Dict, inited bool, errReply reply.ErrorReply) {
dict, errReply = db.getAsDict(key) dict, errReply = db.getAsDict(key)
if errReply != nil { if errReply != nil {
return nil, false, errReply return nil, false, errReply
} }
inited = false inited = false
if dict == nil { if dict == nil {
dict = Dict.MakeSimple() dict = Dict.MakeSimple()
db.Put(key, &DataEntity{ db.Put(key, &DataEntity{
Data: dict, Data: dict,
}) })
inited = true inited = true
} }
return dict, inited, nil return dict, inited, nil
} }
func HSet(db *DB, args [][]byte) redis.Reply { func HSet(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) != 3 { if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hset' command") return reply.MakeErrReply("ERR wrong number of arguments for 'hset' command")
} }
key := string(args[0]) key := string(args[0])
field := string(args[1]) field := string(args[1])
value := args[2] value := args[2]
// lock // lock
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
// get or init entity // get or init entity
dict, _, errReply := db.getOrInitDict(key) dict, _, errReply := db.getOrInitDict(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
result := dict.Put(field, value) result := dict.Put(field, value)
db.AddAof(makeAofCmd("hset", args)) db.AddAof(makeAofCmd("hset", args))
return reply.MakeIntReply(int64(result)) return reply.MakeIntReply(int64(result))
} }
func HSetNX(db *DB, args [][]byte) redis.Reply { func HSetNX(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) != 3 { if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hsetnx' command") return reply.MakeErrReply("ERR wrong number of arguments for 'hsetnx' command")
} }
key := string(args[0]) key := string(args[0])
field := string(args[1]) field := string(args[1])
value := args[2] value := args[2]
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
dict, _, errReply := db.getOrInitDict(key) dict, _, errReply := db.getOrInitDict(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
result := dict.PutIfAbsent(field, value) result := dict.PutIfAbsent(field, value)
if result > 0 { if result > 0 {
db.AddAof(makeAofCmd("hsetnx", args)) db.AddAof(makeAofCmd("hsetnx", args))
} }
return reply.MakeIntReply(int64(result)) return reply.MakeIntReply(int64(result))
} }
func HGet(db *DB, args [][]byte) redis.Reply { func HGet(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) != 2 { if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hget' command") return reply.MakeErrReply("ERR wrong number of arguments for 'hget' command")
} }
key := string(args[0]) key := string(args[0])
field := string(args[1]) field := string(args[1])
// get entity // get entity
dict, errReply := db.getAsDict(key) dict, errReply := db.getAsDict(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if dict == nil { if dict == nil {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
raw, exists := dict.Get(field) raw, exists := dict.Get(field)
if !exists { if !exists {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
value, _ := raw.([]byte) value, _ := raw.([]byte)
return reply.MakeBulkReply(value) return reply.MakeBulkReply(value)
} }
func HExists(db *DB, args [][]byte) redis.Reply { func HExists(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) != 2 { if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hexists' command") return reply.MakeErrReply("ERR wrong number of arguments for 'hexists' command")
} }
key := string(args[0]) key := string(args[0])
field := string(args[1]) field := string(args[1])
// get entity // get entity
dict, errReply := db.getAsDict(key) dict, errReply := db.getAsDict(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if dict == nil { if dict == nil {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
_, exists := dict.Get(field) _, exists := dict.Get(field)
if exists { if exists {
return reply.MakeIntReply(1) return reply.MakeIntReply(1)
} }
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
func HDel(db *DB, args [][]byte) redis.Reply { func HDel(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) < 2 { if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hdel' command") return reply.MakeErrReply("ERR wrong number of arguments for 'hdel' command")
} }
key := string(args[0]) key := string(args[0])
fields := make([]string, len(args) - 1) fields := make([]string, len(args)-1)
fieldArgs := args[1:] fieldArgs := args[1:]
for i, v := range fieldArgs { for i, v := range fieldArgs {
fields[i] = string(v) fields[i] = string(v)
} }
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
// get entity // get entity
dict, errReply := db.getAsDict(key) dict, errReply := db.getAsDict(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if dict == nil { if dict == nil {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
deleted := 0 deleted := 0
for _, field := range fields { for _, field := range fields {
result := dict.Remove(field) result := dict.Remove(field)
deleted += result deleted += result
} }
if dict.Len() == 0 { if dict.Len() == 0 {
db.Remove(key) db.Remove(key)
} }
if deleted > 0 { if deleted > 0 {
db.AddAof(makeAofCmd("hdel", args)) db.AddAof(makeAofCmd("hdel", args))
} }
return reply.MakeIntReply(int64(deleted)) return reply.MakeIntReply(int64(deleted))
} }
func HLen(db *DB, args [][]byte) redis.Reply { func HLen(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) != 1 { if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hlen' command") return reply.MakeErrReply("ERR wrong number of arguments for 'hlen' command")
} }
key := string(args[0]) key := string(args[0])
dict, errReply := db.getAsDict(key) dict, errReply := db.getAsDict(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if dict == nil { if dict == nil {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
return reply.MakeIntReply(int64(dict.Len())) return reply.MakeIntReply(int64(dict.Len()))
} }
func HMSet(db *DB, args [][]byte) redis.Reply { func HMSet(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) < 3 || len(args) % 2 != 1 { if len(args) < 3 || len(args)%2 != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hmset' command") return reply.MakeErrReply("ERR wrong number of arguments for 'hmset' command")
} }
key := string(args[0]) key := string(args[0])
size := (len(args) - 1) / 2 size := (len(args) - 1) / 2
fields := make([]string, size) fields := make([]string, size)
values := make([][]byte, size) values := make([][]byte, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
fields[i] = string(args[2 * i + 1]) fields[i] = string(args[2*i+1])
values[i] = args[2 * i + 2] values[i] = args[2*i+2]
} }
// lock key // lock key
db.Locker.Lock(key) db.Locker.Lock(key)
defer db.Locker.UnLock(key) defer db.Locker.UnLock(key)
// get or init entity // get or init entity
dict, _, errReply := db.getOrInitDict(key) dict, _, errReply := db.getOrInitDict(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
// put data // put data
for i, field := range fields { for i, field := range fields {
value := values[i] value := values[i]
dict.Put(field, value) dict.Put(field, value)
} }
db.AddAof(makeAofCmd("hmset", args)) db.AddAof(makeAofCmd("hmset", args))
return &reply.OkReply{} return &reply.OkReply{}
} }
func HMGet(db *DB, args [][]byte) redis.Reply { func HMGet(db *DB, args [][]byte) redis.Reply {
if len(args) < 2 { if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hmget' command") return reply.MakeErrReply("ERR wrong number of arguments for 'hmget' command")
} }
key := string(args[0]) key := string(args[0])
size := len(args) - 1 size := len(args) - 1
fields := make([]string, size) fields := make([]string, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
fields[i] = string(args[i + 1]) fields[i] = string(args[i+1])
} }
db.RLock(key) db.RLock(key)
defer db.RUnLock(key) defer db.RUnLock(key)
// get entity // get entity
result := make([][]byte, size) result := make([][]byte, size)
dict, errReply := db.getAsDict(key) dict, errReply := db.getAsDict(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if dict == nil { if dict == nil {
return reply.MakeMultiBulkReply(result) return reply.MakeMultiBulkReply(result)
} }
for i, field := range fields { for i, field := range fields {
value, ok := dict.Get(field) value, ok := dict.Get(field)
if !ok { if !ok {
result[i] = nil result[i] = nil
} else { } else {
bytes, _ := value.([]byte) bytes, _ := value.([]byte)
result[i] = bytes result[i] = bytes
} }
} }
return reply.MakeMultiBulkReply(result) return reply.MakeMultiBulkReply(result)
} }
func HKeys(db *DB, args [][]byte) redis.Reply { func HKeys(db *DB, args [][]byte) redis.Reply {
if len(args) != 1 { if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hkeys' command") return reply.MakeErrReply("ERR wrong number of arguments for 'hkeys' command")
} }
key := string(args[0]) key := string(args[0])
db.RLock(key) db.RLock(key)
defer db.RUnLock(key) defer db.RUnLock(key)
dict, errReply := db.getAsDict(key) dict, errReply := db.getAsDict(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if dict == nil { if dict == nil {
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
fields := make([][]byte, dict.Len()) fields := make([][]byte, dict.Len())
i := 0 i := 0
dict.ForEach(func(key string, val interface{})bool { dict.ForEach(func(key string, val interface{}) bool {
fields[i] = []byte(key) fields[i] = []byte(key)
i++ i++
return true return true
}) })
return reply.MakeMultiBulkReply(fields[:i]) return reply.MakeMultiBulkReply(fields[:i])
} }
func HVals(db *DB, args [][]byte) redis.Reply { func HVals(db *DB, args [][]byte) redis.Reply {
if len(args) != 1 { if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hvals' command") return reply.MakeErrReply("ERR wrong number of arguments for 'hvals' command")
} }
key := string(args[0]) key := string(args[0])
db.RLock(key) db.RLock(key)
defer db.RUnLock(key) defer db.RUnLock(key)
// get entity // get entity
dict, errReply := db.getAsDict(key) dict, errReply := db.getAsDict(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if dict == nil { if dict == nil {
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
values := make([][]byte, dict.Len()) values := make([][]byte, dict.Len())
i := 0 i := 0
dict.ForEach(func(key string, val interface{})bool { dict.ForEach(func(key string, val interface{}) bool {
values[i], _ = val.([]byte) values[i], _ = val.([]byte)
i++ i++
return true return true
}) })
return reply.MakeMultiBulkReply(values[:i]) return reply.MakeMultiBulkReply(values[:i])
} }
func HGetAll(db *DB, args [][]byte) redis.Reply { func HGetAll(db *DB, args [][]byte) redis.Reply {
if len(args) != 1 { if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hgetAll' command") return reply.MakeErrReply("ERR wrong number of arguments for 'hgetAll' command")
} }
key := string(args[0]) key := string(args[0])
db.RLock(key) db.RLock(key)
defer db.RUnLock(key) defer db.RUnLock(key)
// get entity // get entity
dict, errReply := db.getAsDict(key) dict, errReply := db.getAsDict(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if dict == nil { if dict == nil {
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
size := dict.Len() size := dict.Len()
result := make([][]byte, size * 2) result := make([][]byte, size*2)
i := 0 i := 0
dict.ForEach(func(key string, val interface{})bool { dict.ForEach(func(key string, val interface{}) bool {
result[i] = []byte(key) result[i] = []byte(key)
i++ i++
result[i], _ = val.([]byte) result[i], _ = val.([]byte)
i++ i++
return true return true
}) })
return reply.MakeMultiBulkReply(result[:i]) return reply.MakeMultiBulkReply(result[:i])
} }
func HIncrBy(db *DB, args [][]byte) redis.Reply { func HIncrBy(db *DB, args [][]byte) redis.Reply {
if len(args) != 3 { if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hincrby' command") return reply.MakeErrReply("ERR wrong number of arguments for 'hincrby' command")
} }
key := string(args[0]) key := string(args[0])
field := string(args[1]) field := string(args[1])
rawDelta := string(args[2]) rawDelta := string(args[2])
delta, err := strconv.ParseInt(rawDelta, 10, 64) delta, err := strconv.ParseInt(rawDelta, 10, 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
db.Locker.Lock(key) db.Locker.Lock(key)
defer db.Locker.UnLock(key) defer db.Locker.UnLock(key)
dict, _, errReply := db.getOrInitDict(key) dict, _, errReply := db.getOrInitDict(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
value, exists := dict.Get(field) value, exists := dict.Get(field)
if !exists { if !exists {
dict.Put(field, args[2]) dict.Put(field, args[2])
db.AddAof(makeAofCmd("hincrby", args)) db.AddAof(makeAofCmd("hincrby", args))
return reply.MakeBulkReply(args[2]) return reply.MakeBulkReply(args[2])
} else { } else {
val, err := strconv.ParseInt(string(value.([]byte)), 10, 64) val, err := strconv.ParseInt(string(value.([]byte)), 10, 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR hash value is not an integer") return reply.MakeErrReply("ERR hash value is not an integer")
} }
val += delta val += delta
bytes := []byte(strconv.FormatInt(val, 10)) bytes := []byte(strconv.FormatInt(val, 10))
dict.Put(field, bytes) dict.Put(field, bytes)
db.AddAof(makeAofCmd("hincrby", args)) db.AddAof(makeAofCmd("hincrby", args))
return reply.MakeBulkReply(bytes) return reply.MakeBulkReply(bytes)
} }
} }
func HIncrByFloat(db *DB, args [][]byte) redis.Reply { func HIncrByFloat(db *DB, args [][]byte) redis.Reply {
if len(args) != 3 { if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hincrbyfloat' command") return reply.MakeErrReply("ERR wrong number of arguments for 'hincrbyfloat' command")
} }
key := string(args[0]) key := string(args[0])
field := string(args[1]) field := string(args[1])
rawDelta := string(args[2]) rawDelta := string(args[2])
delta, err := decimal.NewFromString(rawDelta) delta, err := decimal.NewFromString(rawDelta)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not a valid float") return reply.MakeErrReply("ERR value is not a valid float")
} }
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
// get or init entity // get or init entity
dict, _, errReply := db.getOrInitDict(key) dict, _, errReply := db.getOrInitDict(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
value, exists := dict.Get(field) value, exists := dict.Get(field)
if !exists { if !exists {
dict.Put(field, args[2]) dict.Put(field, args[2])
return reply.MakeBulkReply(args[2]) return reply.MakeBulkReply(args[2])
} else { } else {
val, err := decimal.NewFromString(string(value.([]byte))) val, err := decimal.NewFromString(string(value.([]byte)))
if err != nil { if err != nil {
return reply.MakeErrReply("ERR hash value is not a float") return reply.MakeErrReply("ERR hash value is not a float")
} }
result := val.Add(delta) result := val.Add(delta)
resultBytes:= []byte(result.String()) resultBytes := []byte(result.String())
dict.Put(field, resultBytes) dict.Put(field, resultBytes)
db.AddAof(makeAofCmd("hincrbyfloat", args)) db.AddAof(makeAofCmd("hincrbyfloat", args))
return reply.MakeBulkReply(resultBytes) return reply.MakeBulkReply(resultBytes)
} }
} }

View File

@@ -1,439 +1,439 @@
package db package db
import ( import (
List "github.com/HDT3213/godis/src/datastruct/list" List "github.com/HDT3213/godis/src/datastruct/list"
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"strconv" "strconv"
) )
func (db *DB) getAsList(key string)(*List.LinkedList, reply.ErrorReply) { func (db *DB) getAsList(key string) (*List.LinkedList, reply.ErrorReply) {
entity, ok := db.Get(key) entity, ok := db.Get(key)
if !ok { if !ok {
return nil, nil return nil, nil
} }
bytes, ok := entity.Data.(*List.LinkedList) bytes, ok := entity.Data.(*List.LinkedList)
if !ok { if !ok {
return nil, &reply.WrongTypeErrReply{} return nil, &reply.WrongTypeErrReply{}
} }
return bytes, nil return bytes, nil
} }
func (db *DB) getOrInitList(key string)(list *List.LinkedList, inited bool, errReply reply.ErrorReply) { func (db *DB) getOrInitList(key string) (list *List.LinkedList, inited bool, errReply reply.ErrorReply) {
list, errReply = db.getAsList(key) list, errReply = db.getAsList(key)
if errReply != nil { if errReply != nil {
return nil, false, errReply return nil, false, errReply
} }
inited = false inited = false
if list == nil { if list == nil {
list = &List.LinkedList{} list = &List.LinkedList{}
db.Put(key, &DataEntity{ db.Put(key, &DataEntity{
Data: list, Data: list,
}) })
inited = true inited = true
} }
return list, inited, nil return list, inited, nil
} }
func LIndex(db *DB, args [][]byte) redis.Reply { func LIndex(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) != 2 { if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command") return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command")
} }
key := string(args[0]) key := string(args[0])
index64, err := strconv.ParseInt(string(args[1]), 10, 64) index64, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
index := int(index64) index := int(index64)
// get entity // get entity
list, errReply := db.getAsList(key) list, errReply := db.getAsList(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if list == nil { if list == nil {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
size := list.Len() // assert: size > 0 size := list.Len() // assert: size > 0
if index < -1 * size { if index < -1*size {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} else if index < 0 { } else if index < 0 {
index = size + index index = size + index
} else if index >= size { } else if index >= size {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
val, _ := list.Get(index).([]byte) val, _ := list.Get(index).([]byte)
return reply.MakeBulkReply(val) return reply.MakeBulkReply(val)
} }
func LLen(db *DB, args [][]byte) redis.Reply { func LLen(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) != 1 { if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'llen' command") return reply.MakeErrReply("ERR wrong number of arguments for 'llen' command")
} }
key := string(args[0]) key := string(args[0])
list, errReply := db.getAsList(key) list, errReply := db.getAsList(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if list == nil { if list == nil {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
size := int64(list.Len()) size := int64(list.Len())
return reply.MakeIntReply(size) return reply.MakeIntReply(size)
} }
func LPop(db *DB, args [][]byte) redis.Reply { func LPop(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) != 1 { if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command") return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command")
} }
key := string(args[0]) key := string(args[0])
// lock // lock
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
// get data // get data
list, errReply := db.getAsList(key) list, errReply := db.getAsList(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if list == nil { if list == nil {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
val, _ := list.Remove(0).([]byte) val, _ := list.Remove(0).([]byte)
if list.Len() == 0 { if list.Len() == 0 {
db.Remove(key) db.Remove(key)
} }
db.AddAof(makeAofCmd("lpop", args)) db.AddAof(makeAofCmd("lpop", args))
return reply.MakeBulkReply(val) return reply.MakeBulkReply(val)
} }
func LPush(db *DB, args [][]byte) redis.Reply { func LPush(db *DB, args [][]byte) redis.Reply {
if len(args) < 2 { if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lpush' command") return reply.MakeErrReply("ERR wrong number of arguments for 'lpush' command")
} }
key := string(args[0]) key := string(args[0])
values := args[1:] values := args[1:]
// lock // lock
db.Locker.Lock(key) db.Locker.Lock(key)
defer db.Locker.UnLock(key) defer db.Locker.UnLock(key)
// get or init entity // get or init entity
list, _, errReply := db.getOrInitList(key) list, _, errReply := db.getOrInitList(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
// insert // insert
for _, value := range values { for _, value := range values {
list.Insert(0, value) list.Insert(0, value)
} }
db.AddAof(makeAofCmd("lpush", args)) db.AddAof(makeAofCmd("lpush", args))
return reply.MakeIntReply(int64(list.Len())) return reply.MakeIntReply(int64(list.Len()))
} }
func LPushX(db *DB, args [][]byte) redis.Reply { func LPushX(db *DB, args [][]byte) redis.Reply {
if len(args) < 2 { if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lpushx' command") return reply.MakeErrReply("ERR wrong number of arguments for 'lpushx' command")
} }
key := string(args[0]) key := string(args[0])
values := args[1:] values := args[1:]
// lock // lock
db.Locker.Lock(key) db.Locker.Lock(key)
defer db.Locker.UnLock(key) defer db.Locker.UnLock(key)
// get or init entity // get or init entity
list, errReply := db.getAsList(key) list, errReply := db.getAsList(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if list == nil { if list == nil {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
// insert // insert
for _, value := range values { for _, value := range values {
list.Insert(0, value) list.Insert(0, value)
} }
db.AddAof(makeAofCmd("lpushx", args)) db.AddAof(makeAofCmd("lpushx", args))
return reply.MakeIntReply(int64(list.Len())) return reply.MakeIntReply(int64(list.Len()))
} }
func LRange(db *DB, args [][]byte) redis.Reply { func LRange(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) != 3 { if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lrange' command") return reply.MakeErrReply("ERR wrong number of arguments for 'lrange' command")
} }
key := string(args[0]) key := string(args[0])
start64, err := strconv.ParseInt(string(args[1]), 10, 64) start64, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
start := int(start64) start := int(start64)
stop64, err := strconv.ParseInt(string(args[2]), 10, 64) stop64, err := strconv.ParseInt(string(args[2]), 10, 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
stop := int(stop64) stop := int(stop64)
// lock key // lock key
db.RLock(key) db.RLock(key)
defer db.RUnLock(key) defer db.RUnLock(key)
// get data // get data
list, errReply := db.getAsList(key) list, errReply := db.getAsList(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if list == nil { if list == nil {
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
// compute index // compute index
size := list.Len() // assert: size > 0 size := list.Len() // assert: size > 0
if start < -1 * size { if start < -1*size {
start = 0 start = 0
} else if start < 0 { } else if start < 0 {
start = size + start start = size + start
} else if start >= size { } else if start >= size {
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
if stop < -1 * size { if stop < -1*size {
stop = 0 stop = 0
} else if stop < 0 { } else if stop < 0 {
stop = size + stop + 1 stop = size + stop + 1
} else if stop < size { } else if stop < size {
stop = stop + 1 stop = stop + 1
} else { } else {
stop = size stop = size
} }
if stop < start { if stop < start {
stop = start stop = start
} }
// assert: start in [0, size - 1], stop in [start, size] // assert: start in [0, size - 1], stop in [start, size]
slice := list.Range(start, stop) slice := list.Range(start, stop)
result := make([][]byte, len(slice)) result := make([][]byte, len(slice))
for i, raw := range slice { for i, raw := range slice {
bytes, _ := raw.([]byte) bytes, _ := raw.([]byte)
result[i] = bytes result[i] = bytes
} }
return reply.MakeMultiBulkReply(result) return reply.MakeMultiBulkReply(result)
} }
func LRem(db *DB, args [][]byte) redis.Reply { func LRem(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) != 3 { if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lrem' command") return reply.MakeErrReply("ERR wrong number of arguments for 'lrem' command")
} }
key := string(args[0]) key := string(args[0])
count64, err := strconv.ParseInt(string(args[1]), 10, 64) count64, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
count := int(count64) count := int(count64)
value := args[2] value := args[2]
// lock // lock
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
// get data entity // get data entity
list, errReply := db.getAsList(key) list, errReply := db.getAsList(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if list == nil { if list == nil {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
var removed int var removed int
if count == 0 { if count == 0 {
removed = list.RemoveAllByVal(value) removed = list.RemoveAllByVal(value)
} else if count > 0 { } else if count > 0 {
removed = list.RemoveByVal(value, count) removed = list.RemoveByVal(value, count)
} else { } else {
removed = list.ReverseRemoveByVal(value, -count) removed = list.ReverseRemoveByVal(value, -count)
} }
if list.Len() == 0 { if list.Len() == 0 {
db.Remove(key) db.Remove(key)
} }
if removed > 0 { if removed > 0 {
db.AddAof(makeAofCmd("lrem", args)) db.AddAof(makeAofCmd("lrem", args))
} }
return reply.MakeIntReply(int64(removed)) return reply.MakeIntReply(int64(removed))
} }
func LSet(db *DB, args [][]byte) redis.Reply { func LSet(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) != 3 { if len(args) != 3 {
return reply.MakeErrReply("ERR wrong number of arguments for 'lset' command") return reply.MakeErrReply("ERR wrong number of arguments for 'lset' command")
} }
key := string(args[0]) key := string(args[0])
index64, err := strconv.ParseInt(string(args[1]), 10, 64) index64, err := strconv.ParseInt(string(args[1]), 10, 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not an integer or out of range") return reply.MakeErrReply("ERR value is not an integer or out of range")
} }
index := int(index64) index := int(index64)
value := args[2] value := args[2]
// lock // lock
db.Locker.Lock(key) db.Locker.Lock(key)
defer db.Locker.UnLock(key) defer db.Locker.UnLock(key)
// get data // get data
list, errReply := db.getAsList(key) list, errReply := db.getAsList(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if list == nil { if list == nil {
return reply.MakeErrReply("ERR no such key") return reply.MakeErrReply("ERR no such key")
} }
size := list.Len() // assert: size > 0 size := list.Len() // assert: size > 0
if index < -1 * size { if index < -1*size {
return reply.MakeErrReply("ERR index out of range") return reply.MakeErrReply("ERR index out of range")
} else if index < 0 { } else if index < 0 {
index = size + index index = size + index
} else if index >= size { } else if index >= size {
return reply.MakeErrReply("ERR index out of range") return reply.MakeErrReply("ERR index out of range")
} }
list.Set(index, value) list.Set(index, value)
db.AddAof(makeAofCmd("lset", args)) db.AddAof(makeAofCmd("lset", args))
return &reply.OkReply{} return &reply.OkReply{}
} }
func RPop(db *DB, args [][]byte) redis.Reply { func RPop(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) != 1 { if len(args) != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rpop' command") return reply.MakeErrReply("ERR wrong number of arguments for 'rpop' command")
} }
key := string(args[0]) key := string(args[0])
// lock // lock
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
// get data // get data
list, errReply := db.getAsList(key) list, errReply := db.getAsList(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if list == nil { if list == nil {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
val, _ := list.RemoveLast().([]byte) val, _ := list.RemoveLast().([]byte)
if list.Len() == 0 { if list.Len() == 0 {
db.Remove(key) db.Remove(key)
} }
db.AddAof(makeAofCmd("rpop", args)) db.AddAof(makeAofCmd("rpop", args))
return reply.MakeBulkReply(val) return reply.MakeBulkReply(val)
} }
func RPopLPush(db *DB, args [][]byte) redis.Reply { func RPopLPush(db *DB, args [][]byte) redis.Reply {
if len(args) != 2 { if len(args) != 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rpoplpush' command") return reply.MakeErrReply("ERR wrong number of arguments for 'rpoplpush' command")
} }
sourceKey := string(args[0]) sourceKey := string(args[0])
destKey := string(args[1]) destKey := string(args[1])
// lock // lock
db.Locks(sourceKey, destKey) db.Locks(sourceKey, destKey)
defer db.UnLocks(sourceKey, destKey) defer db.UnLocks(sourceKey, destKey)
// get source entity // get source entity
sourceList, errReply := db.getAsList(sourceKey) sourceList, errReply := db.getAsList(sourceKey)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if sourceList == nil { if sourceList == nil {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} }
// get dest entity // get dest entity
destList, _, errReply := db.getOrInitList(destKey) destList, _, errReply := db.getOrInitList(destKey)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
// pop and push // pop and push
val, _ := sourceList.RemoveLast().([]byte) val, _ := sourceList.RemoveLast().([]byte)
destList.Insert(0, val) destList.Insert(0, val)
if sourceList.Len() == 0 { if sourceList.Len() == 0 {
db.Remove(sourceKey) db.Remove(sourceKey)
} }
db.AddAof(makeAofCmd("rpoplpush", args)) db.AddAof(makeAofCmd("rpoplpush", args))
return reply.MakeBulkReply(val) return reply.MakeBulkReply(val)
} }
func RPush(db *DB, args [][]byte) redis.Reply { func RPush(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) < 2 { if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command") return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command")
} }
key := string(args[0]) key := string(args[0])
values := args[1:] values := args[1:]
// lock // lock
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
// get or init entity // get or init entity
list, _, errReply := db.getOrInitList(key) list, _, errReply := db.getOrInitList(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
// put list // put list
for _, value := range values { for _, value := range values {
list.Add(value) list.Add(value)
} }
db.AddAof(makeAofCmd("rpush", args)) db.AddAof(makeAofCmd("rpush", args))
return reply.MakeIntReply(int64(list.Len())) return reply.MakeIntReply(int64(list.Len()))
} }
func RPushX(db *DB, args [][]byte) redis.Reply { func RPushX(db *DB, args [][]byte) redis.Reply {
if len(args) < 2 { if len(args) < 2 {
return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command") return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command")
} }
key := string(args[0]) key := string(args[0])
values := args[1:] values := args[1:]
// lock // lock
db.Lock(key) db.Lock(key)
defer db.UnLock(key) defer db.UnLock(key)
// get or init entity // get or init entity
list, errReply := db.getAsList(key) list, errReply := db.getAsList(key)
if errReply != nil { if errReply != nil {
return errReply return errReply
} }
if list == nil { if list == nil {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
// put list // put list
for _, value := range values { for _, value := range values {
list.Add(value) list.Add(value)
} }
db.AddAof(makeAofCmd("rpushx", args)) db.AddAof(makeAofCmd("rpushx", args))
return reply.MakeIntReply(int64(list.Len())) return reply.MakeIntReply(int64(list.Len()))
} }

View File

@@ -1,102 +1,102 @@
package db package db
func MakeRouter()map[string]CmdFunc { func MakeRouter() map[string]CmdFunc {
routerMap := make(map[string]CmdFunc) routerMap := make(map[string]CmdFunc)
routerMap["ping"] = Ping routerMap["ping"] = Ping
routerMap["del"] = Del routerMap["del"] = Del
routerMap["expire"] = Expire routerMap["expire"] = Expire
routerMap["expireat"] = ExpireAt routerMap["expireat"] = ExpireAt
routerMap["pexpire"] = PExpire routerMap["pexpire"] = PExpire
routerMap["pexpireat"] = PExpireAt routerMap["pexpireat"] = PExpireAt
routerMap["ttl"] = TTL routerMap["ttl"] = TTL
routerMap["pttl"] = PTTL routerMap["pttl"] = PTTL
routerMap["persist"] = Persist routerMap["persist"] = Persist
routerMap["exists"] = Exists routerMap["exists"] = Exists
routerMap["type"] = Type routerMap["type"] = Type
routerMap["rename"] = Rename routerMap["rename"] = Rename
routerMap["renamenx"] = RenameNx routerMap["renamenx"] = RenameNx
routerMap["set"] = Set routerMap["set"] = Set
routerMap["setnx"] = SetNX routerMap["setnx"] = SetNX
routerMap["setex"] = SetEX routerMap["setex"] = SetEX
routerMap["psetex"] = PSetEX routerMap["psetex"] = PSetEX
routerMap["mset"] = MSet routerMap["mset"] = MSet
routerMap["mget"] = MGet routerMap["mget"] = MGet
routerMap["msetnx"] = MSetNX routerMap["msetnx"] = MSetNX
routerMap["get"] = Get routerMap["get"] = Get
routerMap["getset"] = GetSet routerMap["getset"] = GetSet
routerMap["incr"] = Incr routerMap["incr"] = Incr
routerMap["incrby"] = IncrBy routerMap["incrby"] = IncrBy
routerMap["incrbyfloat"] = IncrByFloat routerMap["incrbyfloat"] = IncrByFloat
routerMap["decr"] = Decr routerMap["decr"] = Decr
routerMap["decrby"] = DecrBy routerMap["decrby"] = DecrBy
routerMap["lpush"] = LPush routerMap["lpush"] = LPush
routerMap["lpushx"] = LPushX routerMap["lpushx"] = LPushX
routerMap["rpush"] = RPush routerMap["rpush"] = RPush
routerMap["rpushx"] = RPushX routerMap["rpushx"] = RPushX
routerMap["lpop"] = LPop routerMap["lpop"] = LPop
routerMap["rpop"] = RPop routerMap["rpop"] = RPop
routerMap["rpoplpush"] = RPopLPush routerMap["rpoplpush"] = RPopLPush
routerMap["lrem"] = LRem routerMap["lrem"] = LRem
routerMap["llen"] = LLen routerMap["llen"] = LLen
routerMap["lindex"] = LIndex routerMap["lindex"] = LIndex
routerMap["lset"] = LSet routerMap["lset"] = LSet
routerMap["lrange"] = LRange routerMap["lrange"] = LRange
routerMap["hset"] = HSet routerMap["hset"] = HSet
routerMap["hsetnx"] = HSetNX routerMap["hsetnx"] = HSetNX
routerMap["hget"] = HGet routerMap["hget"] = HGet
routerMap["hexists"] = HExists routerMap["hexists"] = HExists
routerMap["hdel"] = HDel routerMap["hdel"] = HDel
routerMap["hlen"] = HLen routerMap["hlen"] = HLen
routerMap["hmget"] = HMGet routerMap["hmget"] = HMGet
routerMap["hmset"] = HMSet routerMap["hmset"] = HMSet
routerMap["hkeys"] = HKeys routerMap["hkeys"] = HKeys
routerMap["hvals"] = HVals routerMap["hvals"] = HVals
routerMap["hgetall"] = HGetAll routerMap["hgetall"] = HGetAll
routerMap["hincrby"] = HIncrBy routerMap["hincrby"] = HIncrBy
routerMap["hincrbyfloat"] = HIncrByFloat routerMap["hincrbyfloat"] = HIncrByFloat
routerMap["sadd"] = SAdd routerMap["sadd"] = SAdd
routerMap["sismember"] = SIsMember routerMap["sismember"] = SIsMember
routerMap["srem"] = SRem routerMap["srem"] = SRem
routerMap["scard"] = SCard routerMap["scard"] = SCard
routerMap["smembers"] = SMembers routerMap["smembers"] = SMembers
routerMap["sinter"] = SInter routerMap["sinter"] = SInter
routerMap["sinterstore"] = SInterStore routerMap["sinterstore"] = SInterStore
routerMap["sunion"] = SUnion routerMap["sunion"] = SUnion
routerMap["sunionstore"] = SUnionStore routerMap["sunionstore"] = SUnionStore
routerMap["sdiff"] = SDiff routerMap["sdiff"] = SDiff
routerMap["sdiffstore"] = SDiffStore routerMap["sdiffstore"] = SDiffStore
routerMap["srandmember"] = SRandMember routerMap["srandmember"] = SRandMember
routerMap["zadd"] = ZAdd routerMap["zadd"] = ZAdd
routerMap["zscore"] = ZScore routerMap["zscore"] = ZScore
routerMap["zincrby"] = ZIncrBy routerMap["zincrby"] = ZIncrBy
routerMap["zrank"] = ZRank routerMap["zrank"] = ZRank
routerMap["zcount"] = ZCount routerMap["zcount"] = ZCount
routerMap["zrevrank"] = ZRevRank routerMap["zrevrank"] = ZRevRank
routerMap["zcard"] = ZCard routerMap["zcard"] = ZCard
routerMap["zrange"] = ZRange routerMap["zrange"] = ZRange
routerMap["zrevrange"] = ZRevRange routerMap["zrevrange"] = ZRevRange
routerMap["zrangebyscore"] = ZRangeByScore routerMap["zrangebyscore"] = ZRangeByScore
routerMap["zrevrangebyscore"] = ZRevRangeByScore routerMap["zrevrangebyscore"] = ZRevRangeByScore
routerMap["zrem"] = ZRem routerMap["zrem"] = ZRem
routerMap["zremrangebyscore"] = ZRemRangeByScore routerMap["zremrangebyscore"] = ZRemRangeByScore
routerMap["zremrangebyrank"] = ZRemRangeByRank routerMap["zremrangebyrank"] = ZRemRangeByRank
routerMap["geoadd"] = GeoAdd routerMap["geoadd"] = GeoAdd
routerMap["geopos"] = GeoPos routerMap["geopos"] = GeoPos
routerMap["geodist"] = GeoDist routerMap["geodist"] = GeoDist
routerMap["geohash"] = GeoHash routerMap["geohash"] = GeoHash
routerMap["georadius"] = GeoRadius routerMap["georadius"] = GeoRadius
routerMap["georadiusbymember"] = GeoRadiusByMember routerMap["georadiusbymember"] = GeoRadiusByMember
routerMap["flushdb"] = FlushDB routerMap["flushdb"] = FlushDB
routerMap["flushall"] = FlushAll routerMap["flushall"] = FlushAll
routerMap["keys"] = Keys routerMap["keys"] = Keys
return routerMap return routerMap
} }

View File

@@ -1,16 +1,16 @@
package db package db
import ( import (
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
) )
func Ping(db *DB, args [][]byte) redis.Reply { func Ping(db *DB, args [][]byte) redis.Reply {
if len(args) == 0 { if len(args) == 0 {
return &reply.PongReply{} return &reply.PongReply{}
} else if len(args) == 1 { } else if len(args) == 1 {
return reply.MakeStatusReply("\"" + string(args[0]) + "\"") return reply.MakeStatusReply("\"" + string(args[0]) + "\"")
} else { } else {
return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command") return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command")
} }
} }

File diff suppressed because it is too large Load Diff

View File

@@ -3,7 +3,7 @@ package db
import "github.com/HDT3213/godis/src/interface/redis" import "github.com/HDT3213/godis/src/interface/redis"
type DB interface { type DB interface {
Exec(client redis.Connection, args [][]byte) redis.Reply Exec(client redis.Connection, args [][]byte) redis.Reply
AfterClientClose(c redis.Connection) AfterClientClose(c redis.Connection)
Close() Close()
} }

View File

@@ -1,11 +1,11 @@
package redis package redis
type Connection interface { type Connection interface {
Write([]byte) error Write([]byte) error
// client should keep its subscribing channels // client should keep its subscribing channels
SubsChannel(channel string) SubsChannel(channel string)
UnSubsChannel(channel string) UnSubsChannel(channel string)
SubsCount()int SubsCount() int
GetChannels()[]string GetChannels() []string
} }

View File

@@ -1,5 +1,5 @@
package redis package redis
type Reply interface { type Reply interface {
ToBytes()[]byte ToBytes() []byte
} }

View File

@@ -1,13 +1,13 @@
package tcp package tcp
import ( import (
"net" "context"
"context" "net"
) )
type HandleFunc func(ctx context.Context, conn net.Conn) type HandleFunc func(ctx context.Context, conn net.Conn)
type Handler interface { type Handler interface {
Handle(ctx context.Context, conn net.Conn) Handle(ctx context.Context, conn net.Conn)
Close()error Close() error
} }

View File

@@ -1,80 +1,80 @@
package consistenthash package consistenthash
import ( import (
"hash/crc32" "hash/crc32"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
) )
type HashFunc func(data []byte) uint32 type HashFunc func(data []byte) uint32
type Map struct { type Map struct {
hashFunc HashFunc hashFunc HashFunc
replicas int replicas int
keys []int // sorted keys []int // sorted
hashMap map[int]string hashMap map[int]string
} }
func New(replicas int, fn HashFunc) *Map { func New(replicas int, fn HashFunc) *Map {
m := &Map{ m := &Map{
replicas: replicas, replicas: replicas,
hashFunc: fn, hashFunc: fn,
hashMap: make(map[int]string), hashMap: make(map[int]string),
} }
if m.hashFunc == nil { if m.hashFunc == nil {
m.hashFunc = crc32.ChecksumIEEE m.hashFunc = crc32.ChecksumIEEE
} }
return m return m
} }
func (m *Map) IsEmpty() bool { func (m *Map) IsEmpty() bool {
return len(m.keys) == 0 return len(m.keys) == 0
} }
func (m *Map) Add(keys ...string) { func (m *Map) Add(keys ...string) {
for _, key := range keys { for _, key := range keys {
if key == "" { if key == "" {
continue continue
} }
for i := 0; i < m.replicas; i++ { for i := 0; i < m.replicas; i++ {
hash := int(m.hashFunc([]byte(strconv.Itoa(i) + key))) hash := int(m.hashFunc([]byte(strconv.Itoa(i) + key)))
m.keys = append(m.keys, hash) m.keys = append(m.keys, hash)
m.hashMap[hash] = key m.hashMap[hash] = key
} }
} }
sort.Ints(m.keys) sort.Ints(m.keys)
} }
// support hash tag // support hash tag
func getPartitionKey(key string) string { func getPartitionKey(key string) string {
beg := strings.Index(key, "{") beg := strings.Index(key, "{")
if beg == -1 { if beg == -1 {
return key return key
} }
end := strings.Index(key, "}") end := strings.Index(key, "}")
if end == -1 || end == beg+1 { if end == -1 || end == beg+1 {
return key return key
} }
return key[beg+1 : end] return key[beg+1 : end]
} }
// Get gets the closest item in the hash to the provided key. // Get gets the closest item in the hash to the provided key.
func (m *Map) Get(key string) string { func (m *Map) Get(key string) string {
if m.IsEmpty() { if m.IsEmpty() {
return "" return ""
} }
partitionKey := getPartitionKey(key) partitionKey := getPartitionKey(key)
hash := int(m.hashFunc([]byte(partitionKey))) hash := int(m.hashFunc([]byte(partitionKey)))
// Binary search for appropriate replica. // Binary search for appropriate replica.
idx := sort.Search(len(m.keys), func(i int) bool { return m.keys[i] >= hash }) idx := sort.Search(len(m.keys), func(i int) bool { return m.keys[i] >= hash })
// Means we have cycled back to the first replica. // Means we have cycled back to the first replica.
if idx == len(m.keys) { if idx == len(m.keys) {
idx = 0 idx = 0
} }
return m.hashMap[m.keys[idx]] return m.hashMap[m.keys[idx]]
} }

View File

@@ -1,78 +1,78 @@
package files package files
import ( import (
"mime/multipart" "fmt"
"io/ioutil" "io/ioutil"
"path" "mime/multipart"
"os" "os"
"fmt" "path"
) )
func GetSize(f multipart.File) (int, error) { func GetSize(f multipart.File) (int, error) {
content, err := ioutil.ReadAll(f) content, err := ioutil.ReadAll(f)
return len(content), err return len(content), err
} }
func GetExt(fileName string) string { func GetExt(fileName string) string {
return path.Ext(fileName) return path.Ext(fileName)
} }
func CheckNotExist(src string) bool { func CheckNotExist(src string) bool {
_, err := os.Stat(src) _, err := os.Stat(src)
return os.IsNotExist(err) return os.IsNotExist(err)
} }
func CheckPermission(src string) bool { func CheckPermission(src string) bool {
_, err := os.Stat(src) _, err := os.Stat(src)
return os.IsPermission(err) return os.IsPermission(err)
} }
func IsNotExistMkDir(src string) error { func IsNotExistMkDir(src string) error {
if notExist := CheckNotExist(src); notExist == true { if notExist := CheckNotExist(src); notExist == true {
if err := MkDir(src); err != nil { if err := MkDir(src); err != nil {
return err return err
} }
} }
return nil return nil
} }
func MkDir(src string) error { func MkDir(src string) error {
err := os.MkdirAll(src, os.ModePerm) err := os.MkdirAll(src, os.ModePerm)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
func Open(name string, flag int, perm os.FileMode) (*os.File, error) { func Open(name string, flag int, perm os.FileMode) (*os.File, error) {
f, err := os.OpenFile(name, flag, perm) f, err := os.OpenFile(name, flag, perm)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return f, nil return f, nil
} }
func MustOpen(fileName, dir string) (*os.File, error) { func MustOpen(fileName, dir string) (*os.File, error) {
perm := CheckPermission(dir) perm := CheckPermission(dir)
if perm == true { if perm == true {
return nil, fmt.Errorf("permission denied dir: %s", dir) return nil, fmt.Errorf("permission denied dir: %s", dir)
} }
err := IsNotExistMkDir(dir) err := IsNotExistMkDir(dir)
if err != nil { if err != nil {
return nil, fmt.Errorf("error during make dir %s, err: %s", dir, err) return nil, fmt.Errorf("error during make dir %s, err: %s", dir, err)
} }
f, err := Open(dir + string(os.PathSeparator) + fileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0644) f, err := Open(dir+string(os.PathSeparator)+fileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0644)
if err != nil { if err != nil {
return nil, fmt.Errorf("fail to open file, err: %s", err) return nil, fmt.Errorf("fail to open file, err: %s", err)
} }
return f, nil return f, nil
} }

View File

@@ -1,9 +1,9 @@
package geohash package geohash
import ( import (
"bytes" "bytes"
"encoding/base32" "encoding/base32"
"encoding/binary" "encoding/binary"
) )
var bits = []uint8{128, 64, 32, 16, 8, 4, 2, 1} var bits = []uint8{128, 64, 32, 16, 8, 4, 2, 1}
@@ -13,95 +13,95 @@ const defaultBitSize = 64 // 32 bits for latitude, another 32 bits for longitude
// return: geohash, box // return: geohash, box
func encode0(latitude, longitude float64, bitSize uint) ([]byte, [2][2]float64) { func encode0(latitude, longitude float64, bitSize uint) ([]byte, [2][2]float64) {
box := [2][2]float64{ box := [2][2]float64{
{-180, 180}, // lng {-180, 180}, // lng
{-90, 90}, // lat {-90, 90}, // lat
} }
pos := [2]float64{longitude, latitude} pos := [2]float64{longitude, latitude}
hash := &bytes.Buffer{} hash := &bytes.Buffer{}
bit := 0 bit := 0
var precision uint = 0 var precision uint = 0
code := uint8(0) code := uint8(0)
for precision < bitSize { for precision < bitSize {
for direction, val := range pos { for direction, val := range pos {
mid := (box[direction][0] + box[direction][1]) / 2 mid := (box[direction][0] + box[direction][1]) / 2
if val < mid { if val < mid {
box[direction][1] = mid box[direction][1] = mid
} else { } else {
box[direction][0] = mid box[direction][0] = mid
code |= bits[bit] code |= bits[bit]
} }
bit++ bit++
if bit == 8 { if bit == 8 {
hash.WriteByte(code) hash.WriteByte(code)
bit = 0 bit = 0
code = 0 code = 0
} }
precision++ precision++
if precision == bitSize { if precision == bitSize {
break break
} }
} }
} }
// precision%8 > 0 // precision%8 > 0
if code > 0 { if code > 0 {
hash.WriteByte(code) hash.WriteByte(code)
} }
return hash.Bytes(), box return hash.Bytes(), box
} }
func Encode(latitude, longitude float64) uint64 { func Encode(latitude, longitude float64) uint64 {
buf, _ := encode0(latitude, longitude, defaultBitSize) buf, _ := encode0(latitude, longitude, defaultBitSize)
return binary.BigEndian.Uint64(buf) return binary.BigEndian.Uint64(buf)
} }
func decode0(hash []byte) [][]float64 { func decode0(hash []byte) [][]float64 {
box := [][]float64{ box := [][]float64{
{-180, 180}, {-180, 180},
{-90, 90}, {-90, 90},
} }
direction := 0 direction := 0
for i := 0; i < len(hash); i++ { for i := 0; i < len(hash); i++ {
code := hash[i] code := hash[i]
for j := 0; j < len(bits); j++ { for j := 0; j < len(bits); j++ {
mid := (box[direction][0] + box[direction][1]) / 2 mid := (box[direction][0] + box[direction][1]) / 2
mask := bits[j] mask := bits[j]
if mask&code > 0 { if mask&code > 0 {
box[direction][0] = mid box[direction][0] = mid
} else { } else {
box[direction][1] = mid box[direction][1] = mid
} }
direction = (direction + 1) % 2 direction = (direction + 1) % 2
} }
} }
return box return box
} }
func Decode(code uint64) (float64, float64) { func Decode(code uint64) (float64, float64) {
buf := make([]byte, 8) buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, code) binary.BigEndian.PutUint64(buf, code)
box := decode0(buf) box := decode0(buf)
lng := float64(box[0][0]+box[0][1]) / 2 lng := float64(box[0][0]+box[0][1]) / 2
lat := float64(box[1][0]+box[1][1]) / 2 lat := float64(box[1][0]+box[1][1]) / 2
return lat, lng return lat, lng
} }
func ToString(buf []byte) string { func ToString(buf []byte) string {
return enc.EncodeToString(buf) return enc.EncodeToString(buf)
} }
func ToInt(buf []byte) uint64 { func ToInt(buf []byte) uint64 {
// padding // padding
if len(buf) < 8 { if len(buf) < 8 {
buf2 := make([]byte, 8) buf2 := make([]byte, 8)
copy(buf2, buf) copy(buf2, buf)
return binary.BigEndian.Uint64(buf2) return binary.BigEndian.Uint64(buf2)
} }
return binary.BigEndian.Uint64(buf) return binary.BigEndian.Uint64(buf)
} }
func FromInt(code uint64) []byte { func FromInt(code uint64) []byte {
buf := make([]byte, 8) buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, code) binary.BigEndian.PutUint64(buf, code)
return buf return buf
} }

View File

@@ -1,39 +1,39 @@
package geohash package geohash
import ( import (
"fmt" "fmt"
"math" "math"
"testing" "testing"
) )
func TestToRange(t *testing.T) { func TestToRange(t *testing.T) {
neighbor := []byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00} neighbor := []byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00}
range_ := ToRange(neighbor, 36) range_ := ToRange(neighbor, 36)
expectedLower := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00}) expectedLower := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00})
expectedUpper := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xF0, 0x00, 0x00, 0x00}) expectedUpper := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xF0, 0x00, 0x00, 0x00})
if expectedLower != range_[0] { if expectedLower != range_[0] {
t.Error("incorrect lower") t.Error("incorrect lower")
} }
if expectedUpper != range_[1] { if expectedUpper != range_[1] {
t.Error("incorrect upper") t.Error("incorrect upper")
} }
} }
func TestEncode(t *testing.T) { func TestEncode(t *testing.T) {
lat0 := 48.669 lat0 := 48.669
lng0 := -4.32913 lng0 := -4.32913
hash := Encode(lat0, lng0) hash := Encode(lat0, lng0)
str := ToString(FromInt(hash)) str := ToString(FromInt(hash))
if str != "gbsuv7zt7zntw" { if str != "gbsuv7zt7zntw" {
t.Error("encode error") t.Error("encode error")
} }
lat, lng := Decode(hash) lat, lng := Decode(hash)
if math.Abs(lat-lat0) > 1e-6 || math.Abs(lng-lng0) > 1e-6 { if math.Abs(lat-lat0) > 1e-6 || math.Abs(lng-lng0) > 1e-6 {
t.Error("decode error") t.Error("decode error")
} }
} }
func TestGetNeighbours(t *testing.T) { func TestGetNeighbours(t *testing.T) {
ranges := GetNeighbours(90, 180, 630*1000) ranges := GetNeighbours(90, 180, 630*1000)
fmt.Printf("%#v", ranges) fmt.Printf("%#v", ranges)
} }

View File

@@ -3,134 +3,134 @@ package geohash
import "math" import "math"
const ( const (
DR = math.Pi / 180.0 DR = math.Pi / 180.0
EarthRadius = 6372797.560856 EarthRadius = 6372797.560856
MercatorMax = 20037726.37 // pi * EarthRadius MercatorMax = 20037726.37 // pi * EarthRadius
MercatorMin = -20037726.37 MercatorMin = -20037726.37
) )
func degRad(ang float64) float64 { func degRad(ang float64) float64 {
return ang * DR return ang * DR
} }
func radDeg(ang float64) float64 { func radDeg(ang float64) float64 {
return ang / DR return ang / DR
} }
func getBoundingBox(latitude float64, longitude float64, radiusMeters float64) ( func getBoundingBox(latitude float64, longitude float64, radiusMeters float64) (
minLat, maxLat, minLng, maxLng float64) { minLat, maxLat, minLng, maxLng float64) {
minLng = longitude - radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude))) minLng = longitude - radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude)))
if minLng < -180 { if minLng < -180 {
minLng = -180 minLng = -180
} }
maxLng = longitude + radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude))) maxLng = longitude + radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude)))
if maxLng > 180 { if maxLng > 180 {
maxLng = 180 maxLng = 180
} }
minLat = latitude - radDeg(radiusMeters/EarthRadius) minLat = latitude - radDeg(radiusMeters/EarthRadius)
if minLat < -90 { if minLat < -90 {
minLat = -90 minLat = -90
} }
maxLat = latitude + radDeg(radiusMeters/EarthRadius) maxLat = latitude + radDeg(radiusMeters/EarthRadius)
if maxLat > 90 { if maxLat > 90 {
maxLat = 90 maxLat = 90
} }
return return
} }
func estimatePrecisionByRadius(radiusMeters float64, latitude float64) uint { func estimatePrecisionByRadius(radiusMeters float64, latitude float64) uint {
if radiusMeters == 0 { if radiusMeters == 0 {
return defaultBitSize - 1 return defaultBitSize - 1
} }
var precision uint = 1 var precision uint = 1
for radiusMeters < MercatorMax { for radiusMeters < MercatorMax {
radiusMeters *= 2 radiusMeters *= 2
precision++ precision++
} }
/* Make sure range is included in most of the base cases. */ /* Make sure range is included in most of the base cases. */
precision -= 2 precision -= 2
if latitude > 66 || latitude < -66 { if latitude > 66 || latitude < -66 {
precision-- precision--
if latitude > 80 || latitude < -80 { if latitude > 80 || latitude < -80 {
precision-- precision--
} }
} }
if precision < 1 { if precision < 1 {
precision = 1 precision = 1
} }
if precision > 32 { if precision > 32 {
precision = 32 precision = 32
} }
return precision*2 - 1 return precision*2 - 1
} }
func Distance(latitude1, longitude1, latitude2, longitude2 float64) float64 { func Distance(latitude1, longitude1, latitude2, longitude2 float64) float64 {
radLat1 := degRad(latitude1) radLat1 := degRad(latitude1)
radLat2 := degRad(latitude2) radLat2 := degRad(latitude2)
a := radLat1 - radLat2 a := radLat1 - radLat2
b := degRad(longitude1) - degRad(longitude2) b := degRad(longitude1) - degRad(longitude2)
return 2 * EarthRadius * math.Asin(math.Sqrt(math.Pow(math.Sin(a/2), 2) + return 2 * EarthRadius * math.Asin(math.Sqrt(math.Pow(math.Sin(a/2), 2)+
math.Cos(radLat1)*math.Cos(radLat2)*math.Pow(math.Sin(b/2), 2))) math.Cos(radLat1)*math.Cos(radLat2)*math.Pow(math.Sin(b/2), 2)))
} }
func ToRange(scope []byte, precision uint) [2]uint64 { func ToRange(scope []byte, precision uint) [2]uint64 {
lower := ToInt(scope) lower := ToInt(scope)
radius := uint64(1 << (64 - precision)) radius := uint64(1 << (64 - precision))
upper := lower + radius upper := lower + radius
return [2]uint64{lower, upper} return [2]uint64{lower, upper}
} }
func ensureValidLat(lat float64) float64 { func ensureValidLat(lat float64) float64 {
if lat > 90 { if lat > 90 {
return 90 return 90
} }
if lat < -90 { if lat < -90 {
return -90 return -90
} }
return lat return lat
} }
func ensureValidLng(lng float64) float64 { func ensureValidLng(lng float64) float64 {
if lng > 180 { if lng > 180 {
return -360 + lng return -360 + lng
} }
if lng < -180 { if lng < -180 {
return 360 + lng return 360 + lng
} }
return lng return lng
} }
func GetNeighbours(latitude, longitude, radiusMeters float64) [][2]uint64 { func GetNeighbours(latitude, longitude, radiusMeters float64) [][2]uint64 {
precision := estimatePrecisionByRadius(radiusMeters, latitude) precision := estimatePrecisionByRadius(radiusMeters, latitude)
center, box := encode0(latitude, longitude, precision) center, box := encode0(latitude, longitude, precision)
height := box[0][1] - box[0][0] height := box[0][1] - box[0][0]
width := box[1][1] - box[1][0] width := box[1][1] - box[1][0]
centerLng := (box[0][1] + box[0][0]) / 2 centerLng := (box[0][1] + box[0][0]) / 2
centerLat := (box[1][1] + box[1][0]) / 2 centerLat := (box[1][1] + box[1][0]) / 2
maxLat := ensureValidLat(centerLat + height) maxLat := ensureValidLat(centerLat + height)
minLat := ensureValidLat(centerLat - height) minLat := ensureValidLat(centerLat - height)
maxLng := ensureValidLng(centerLng + width) maxLng := ensureValidLng(centerLng + width)
minLng := ensureValidLng(centerLng - width) minLng := ensureValidLng(centerLng - width)
var result [10][2]uint64 var result [10][2]uint64
leftUpper, _ := encode0(maxLat, minLng, precision) leftUpper, _ := encode0(maxLat, minLng, precision)
result[1] = ToRange(leftUpper, precision) result[1] = ToRange(leftUpper, precision)
upper, _ := encode0(maxLat, centerLng, precision) upper, _ := encode0(maxLat, centerLng, precision)
result[2] = ToRange(upper, precision) result[2] = ToRange(upper, precision)
rightUpper, _ := encode0(maxLat, maxLng, precision) rightUpper, _ := encode0(maxLat, maxLng, precision)
result[3] = ToRange(rightUpper, precision) result[3] = ToRange(rightUpper, precision)
left, _ := encode0(centerLat, minLng, precision) left, _ := encode0(centerLat, minLng, precision)
result[4] = ToRange(left, precision) result[4] = ToRange(left, precision)
result[5] = ToRange(center, precision) result[5] = ToRange(center, precision)
right, _ := encode0(centerLat, maxLng, precision) right, _ := encode0(centerLat, maxLng, precision)
result[6] = ToRange(right, precision) result[6] = ToRange(right, precision)
leftDown, _ := encode0(minLat, minLng, precision) leftDown, _ := encode0(minLat, minLng, precision)
result[7] = ToRange(leftDown, precision) result[7] = ToRange(leftDown, precision)
down, _ := encode0(minLat, centerLng, precision) down, _ := encode0(minLat, centerLng, precision)
result[8] = ToRange(down, precision) result[8] = ToRange(down, precision)
rightDown, _ := encode0(minLat, maxLng, precision) rightDown, _ := encode0(minLat, maxLng, precision)
result[9] = ToRange(rightDown, precision) result[9] = ToRange(rightDown, precision)
return result[1:] return result[1:]
} }

View File

@@ -1,21 +1,21 @@
package gob package gob
import ( import (
"bytes" "bytes"
"encoding/gob" "encoding/gob"
) )
func Marshal(obj interface{}) ([]byte, error) { func Marshal(obj interface{}) ([]byte, error) {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf) enc := gob.NewEncoder(buf)
err := enc.Encode(obj) err := enc.Encode(obj)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return buf.Bytes(), nil return buf.Bytes(), nil
} }
func UnMarshal(data []byte, obj interface{}) error { func UnMarshal(data []byte, obj interface{}) error {
dec := gob.NewDecoder(bytes.NewBuffer(data)) dec := gob.NewDecoder(bytes.NewBuffer(data))
return dec.Decode(obj) return dec.Decode(obj)
} }

View File

@@ -4,16 +4,14 @@ import "sync/atomic"
type AtomicBool uint32 type AtomicBool uint32
func (b *AtomicBool)Get()bool { func (b *AtomicBool) Get() bool {
return atomic.LoadUint32((*uint32)(b)) != 0 return atomic.LoadUint32((*uint32)(b)) != 0
} }
func (b *AtomicBool)Set(v bool) { func (b *AtomicBool) Set(v bool) {
if v { if v {
atomic.StoreUint32((*uint32)(b), 1) atomic.StoreUint32((*uint32)(b), 1)
} else { } else {
atomic.StoreUint32((*uint32)(b), 0) atomic.StoreUint32((*uint32)(b), 0)
} }
} }

View File

@@ -1,38 +1,38 @@
package wait package wait
import ( import (
"sync" "sync"
"time" "time"
) )
type Wait struct { type Wait struct {
wg sync.WaitGroup wg sync.WaitGroup
} }
func (w *Wait)Add(delta int) { func (w *Wait) Add(delta int) {
w.wg.Add(delta) w.wg.Add(delta)
} }
func (w *Wait)Done() { func (w *Wait) Done() {
w.wg.Done() w.wg.Done()
} }
func (w *Wait)Wait() { func (w *Wait) Wait() {
w.wg.Wait() w.wg.Wait()
} }
// return isTimeout // return isTimeout
func (w *Wait)WaitWithTimeout(timeout time.Duration)bool { func (w *Wait) WaitWithTimeout(timeout time.Duration) bool {
c := make(chan bool) c := make(chan bool)
go func() { go func() {
defer close(c) defer close(c)
w.wg.Wait() w.wg.Wait()
c <- true c <- true
}() }()
select { select {
case <-c: case <-c:
return false // completed normally return false // completed normally
case <-time.After(timeout): case <-time.After(timeout):
return true // timed out return true // timed out
} }
} }

View File

@@ -1,95 +1,95 @@
package wildcard package wildcard
const ( const (
normal = iota normal = iota
all // * all // *
any // ? any // ?
set_ // [] set_ // []
) )
type item struct { type item struct {
character byte character byte
set map[byte]bool set map[byte]bool
typeCode int typeCode int
} }
func (i *item) contains(c byte) bool { func (i *item) contains(c byte) bool {
_, ok := i.set[c] _, ok := i.set[c]
return ok return ok
} }
type Pattern struct { type Pattern struct {
items []*item items []*item
} }
func CompilePattern(src string) *Pattern { func CompilePattern(src string) *Pattern {
items := make([]*item, 0) items := make([]*item, 0)
escape := false escape := false
inSet := false inSet := false
var set map[byte]bool var set map[byte]bool
for _, v := range src { for _, v := range src {
c := byte(v) c := byte(v)
if escape { if escape {
items = append(items, &item{typeCode: normal, character: c}) items = append(items, &item{typeCode: normal, character: c})
escape = false escape = false
} else if c == '*' { } else if c == '*' {
items = append(items, &item{typeCode: all}) items = append(items, &item{typeCode: all})
} else if c == '?' { } else if c == '?' {
items = append(items, &item{typeCode: any}) items = append(items, &item{typeCode: any})
} else if c == '\\' { } else if c == '\\' {
escape = true escape = true
} else if c == '[' { } else if c == '[' {
if !inSet { if !inSet {
inSet = true inSet = true
set = make(map[byte]bool) set = make(map[byte]bool)
} else { } else {
set[c] = true set[c] = true
} }
} else if c == ']' { } else if c == ']' {
if inSet { if inSet {
inSet = false inSet = false
items = append(items, &item{typeCode: set_, set: set}) items = append(items, &item{typeCode: set_, set: set})
} else { } else {
items = append(items, &item{typeCode: normal, character: c}) items = append(items, &item{typeCode: normal, character: c})
} }
} else { } else {
if inSet { if inSet {
set[c] = true set[c] = true
} else { } else {
items = append(items, &item{typeCode: normal, character: c}) items = append(items, &item{typeCode: normal, character: c})
} }
} }
} }
return &Pattern{ return &Pattern{
items: items, items: items,
} }
} }
func (p *Pattern) IsMatch(s string) bool { func (p *Pattern) IsMatch(s string) bool {
if len(p.items) == 0 { if len(p.items) == 0 {
return len(s) == 0 return len(s) == 0
} }
m := len(s) m := len(s)
n := len(p.items) n := len(p.items)
table := make([][]bool, m+1) table := make([][]bool, m+1)
for i := 0; i < m+1; i++ { for i := 0; i < m+1; i++ {
table[i] = make([]bool, n+1) table[i] = make([]bool, n+1)
} }
table[0][0] = true table[0][0] = true
for j := 1; j < n+1; j++ { for j := 1; j < n+1; j++ {
table[0][j] = table[0][j-1] && p.items[j-1].typeCode == all table[0][j] = table[0][j-1] && p.items[j-1].typeCode == all
} }
for i := 1; i < m+1; i++ { for i := 1; i < m+1; i++ {
for j := 1; j < n+1; j++ { for j := 1; j < n+1; j++ {
if p.items[j-1].typeCode == all { if p.items[j-1].typeCode == all {
table[i][j] = table[i-1][j] || table[i][j-1] table[i][j] = table[i-1][j] || table[i][j-1]
} else { } else {
table[i][j] = table[i-1][j-1] && table[i][j] = table[i-1][j-1] &&
(p.items[j-1].typeCode == any || (p.items[j-1].typeCode == any ||
(p.items[j-1].typeCode == normal && uint8(s[i-1]) == p.items[j-1].character) || (p.items[j-1].typeCode == normal && uint8(s[i-1]) == p.items[j-1].character) ||
(p.items[j-1].typeCode == set_ && p.items[j-1].contains(s[i-1]))) (p.items[j-1].typeCode == set_ && p.items[j-1].contains(s[i-1])))
} }
} }
} }
return table[m][n] return table[m][n]
} }

View File

@@ -3,73 +3,73 @@ package wildcard
import "testing" import "testing"
func TestWildCard(t *testing.T) { func TestWildCard(t *testing.T) {
p := CompilePattern("a") p := CompilePattern("a")
if !p.IsMatch("a") { if !p.IsMatch("a") {
t.Error("expect true actually false") t.Error("expect true actually false")
} }
if p.IsMatch("b") { if p.IsMatch("b") {
t.Error("expect false actually true") t.Error("expect false actually true")
} }
// test '?' // test '?'
p = CompilePattern("a?") p = CompilePattern("a?")
if !p.IsMatch("ab") { if !p.IsMatch("ab") {
t.Error("expect true actually false") t.Error("expect true actually false")
} }
if p.IsMatch("a") { if p.IsMatch("a") {
t.Error("expect false actually true") t.Error("expect false actually true")
} }
if p.IsMatch("abb") { if p.IsMatch("abb") {
t.Error("expect false actually true") t.Error("expect false actually true")
} }
if p.IsMatch("bb") { if p.IsMatch("bb") {
t.Error("expect false actually true") t.Error("expect false actually true")
} }
// test * // test *
p = CompilePattern("a*") p = CompilePattern("a*")
if !p.IsMatch("ab") { if !p.IsMatch("ab") {
t.Error("expect true actually false") t.Error("expect true actually false")
} }
if !p.IsMatch("a") { if !p.IsMatch("a") {
t.Error("expect true actually false") t.Error("expect true actually false")
} }
if !p.IsMatch("abb") { if !p.IsMatch("abb") {
t.Error("expect true actually false") t.Error("expect true actually false")
} }
if p.IsMatch("bb") { if p.IsMatch("bb") {
t.Error("expect false actually true") t.Error("expect false actually true")
} }
// test [] // test []
p = CompilePattern("a[ab[]") p = CompilePattern("a[ab[]")
if !p.IsMatch("ab") { if !p.IsMatch("ab") {
t.Error("expect true actually false") t.Error("expect true actually false")
} }
if !p.IsMatch("aa") { if !p.IsMatch("aa") {
t.Error("expect true actually false") t.Error("expect true actually false")
} }
if !p.IsMatch("a[") { if !p.IsMatch("a[") {
t.Error("expect true actually false") t.Error("expect true actually false")
} }
if p.IsMatch("abb") { if p.IsMatch("abb") {
t.Error("expect false actually true") t.Error("expect false actually true")
} }
if p.IsMatch("bb") { if p.IsMatch("bb") {
t.Error("expect false actually true") t.Error("expect false actually true")
} }
// test escape // test escape
p = CompilePattern("\\\\") // pattern: \\ p = CompilePattern("\\\\") // pattern: \\
if !p.IsMatch("\\") { if !p.IsMatch("\\") {
t.Error("expect true actually false") t.Error("expect true actually false")
} }
p = CompilePattern("\\*") p = CompilePattern("\\*")
if !p.IsMatch("*") { if !p.IsMatch("*") {
t.Error("expect true actually false") t.Error("expect true actually false")
} }
if p.IsMatch("a") { if p.IsMatch("a") {
t.Error("expect false actually true") t.Error("expect false actually true")
} }
} }

View File

@@ -1,20 +1,20 @@
package pubsub package pubsub
import ( import (
"github.com/HDT3213/godis/src/datastruct/dict" "github.com/HDT3213/godis/src/datastruct/dict"
"github.com/HDT3213/godis/src/datastruct/lock" "github.com/HDT3213/godis/src/datastruct/lock"
) )
type Hub struct { type Hub struct {
// channel -> list(*Client) // channel -> list(*Client)
subs dict.Dict subs dict.Dict
// lock channel // lock channel
subsLocker *lock.Locks subsLocker *lock.Locks
} }
func MakeHub() *Hub { func MakeHub() *Hub {
return &Hub{ return &Hub{
subs: dict.MakeConcurrent(4), subs: dict.MakeConcurrent(4),
subsLocker: lock.Make(16), subsLocker: lock.Make(16),
} }
} }

View File

@@ -1,23 +1,23 @@
package pubsub package pubsub
import ( import (
"github.com/HDT3213/godis/src/datastruct/list" "github.com/HDT3213/godis/src/datastruct/list"
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"strconv" "strconv"
) )
var ( var (
_subscribe = "subscribe" _subscribe = "subscribe"
_unsubscribe = "unsubscribe" _unsubscribe = "unsubscribe"
messageBytes = []byte("message") messageBytes = []byte("message")
unSubscribeNothing = []byte("*3\r\n$11\r\nunsubscribe\r\n$-1\n:0\r\n") unSubscribeNothing = []byte("*3\r\n$11\r\nunsubscribe\r\n$-1\n:0\r\n")
) )
func makeMsg(t string, channel string, code int64) []byte { func makeMsg(t string, channel string, code int64) []byte {
return []byte("*3\r\n$" + strconv.FormatInt(int64(len(t)), 10) + reply.CRLF + t + reply.CRLF + return []byte("*3\r\n$" + strconv.FormatInt(int64(len(t)), 10) + reply.CRLF + t + reply.CRLF +
"$" + strconv.FormatInt(int64(len(channel)), 10) + reply.CRLF + channel + reply.CRLF + "$" + strconv.FormatInt(int64(len(channel)), 10) + reply.CRLF + channel + reply.CRLF +
":" + strconv.FormatInt(code, 10) + reply.CRLF) ":" + strconv.FormatInt(code, 10) + reply.CRLF)
} }
/* /*
@@ -25,22 +25,22 @@ func makeMsg(t string, channel string, code int64) []byte {
* return: is new subscribed * return: is new subscribed
*/ */
func subscribe0(hub *Hub, channel string, client redis.Connection) bool { func subscribe0(hub *Hub, channel string, client redis.Connection) bool {
client.SubsChannel(channel) client.SubsChannel(channel)
// add into hub.subs // add into hub.subs
raw, ok := hub.subs.Get(channel) raw, ok := hub.subs.Get(channel)
var subscribers *list.LinkedList var subscribers *list.LinkedList
if ok { if ok {
subscribers, _ = raw.(*list.LinkedList) subscribers, _ = raw.(*list.LinkedList)
} else { } else {
subscribers = list.Make() subscribers = list.Make()
hub.subs.Put(channel, subscribers) hub.subs.Put(channel, subscribers)
} }
if subscribers.Contains(client) { if subscribers.Contains(client) {
return false return false
} }
subscribers.Add(client) subscribers.Add(client)
return true return true
} }
/* /*
@@ -48,102 +48,102 @@ func subscribe0(hub *Hub, channel string, client redis.Connection) bool {
* return: is actually un-subscribe * return: is actually un-subscribe
*/ */
func unsubscribe0(hub *Hub, channel string, client redis.Connection) bool { func unsubscribe0(hub *Hub, channel string, client redis.Connection) bool {
client.UnSubsChannel(channel) client.UnSubsChannel(channel)
// remove from hub.subs // remove from hub.subs
raw, ok := hub.subs.Get(channel) raw, ok := hub.subs.Get(channel)
if ok { if ok {
subscribers, _ := raw.(*list.LinkedList) subscribers, _ := raw.(*list.LinkedList)
subscribers.RemoveAllByVal(client) subscribers.RemoveAllByVal(client)
if subscribers.Len() == 0 { if subscribers.Len() == 0 {
// clean // clean
hub.subs.Remove(channel) hub.subs.Remove(channel)
} }
return true return true
} }
return false return false
} }
func Subscribe(hub *Hub, c redis.Connection, args [][]byte) redis.Reply { func Subscribe(hub *Hub, c redis.Connection, args [][]byte) redis.Reply {
channels := make([]string, len(args)) channels := make([]string, len(args))
for i, b := range args { for i, b := range args {
channels[i] = string(b) channels[i] = string(b)
} }
hub.subsLocker.Locks(channels...) hub.subsLocker.Locks(channels...)
defer hub.subsLocker.UnLocks(channels...) defer hub.subsLocker.UnLocks(channels...)
for _, channel := range channels { for _, channel := range channels {
if subscribe0(hub, channel, c) { if subscribe0(hub, channel, c) {
_ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount()))) _ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount())))
} }
} }
return &reply.NoReply{} return &reply.NoReply{}
} }
func UnsubscribeAll(hub *Hub, c redis.Connection) { func UnsubscribeAll(hub *Hub, c redis.Connection) {
channels := c.GetChannels() channels := c.GetChannels()
hub.subsLocker.Locks(channels...) hub.subsLocker.Locks(channels...)
defer hub.subsLocker.UnLocks(channels...) defer hub.subsLocker.UnLocks(channels...)
for _, channel := range channels { for _, channel := range channels {
unsubscribe0(hub, channel, c) unsubscribe0(hub, channel, c)
} }
} }
func UnSubscribe(db *Hub, c redis.Connection, args [][]byte) redis.Reply { func UnSubscribe(db *Hub, c redis.Connection, args [][]byte) redis.Reply {
var channels []string var channels []string
if len(args) > 0 { if len(args) > 0 {
channels = make([]string, len(args)) channels = make([]string, len(args))
for i, b := range args { for i, b := range args {
channels[i] = string(b) channels[i] = string(b)
} }
} else { } else {
channels = c.GetChannels() channels = c.GetChannels()
} }
db.subsLocker.Locks(channels...) db.subsLocker.Locks(channels...)
defer db.subsLocker.UnLocks(channels...) defer db.subsLocker.UnLocks(channels...)
if len(channels) == 0 { if len(channels) == 0 {
_ = c.Write(unSubscribeNothing) _ = c.Write(unSubscribeNothing)
return &reply.NoReply{} return &reply.NoReply{}
} }
for _, channel := range channels { for _, channel := range channels {
if unsubscribe0(db, channel, c) { if unsubscribe0(db, channel, c) {
_ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount()))) _ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount())))
} }
} }
return &reply.NoReply{} return &reply.NoReply{}
} }
func Publish(hub *Hub, args [][]byte) redis.Reply { func Publish(hub *Hub, args [][]byte) redis.Reply {
if len(args) != 2 { if len(args) != 2 {
return &reply.ArgNumErrReply{Cmd: "publish"} return &reply.ArgNumErrReply{Cmd: "publish"}
} }
channel := string(args[0]) channel := string(args[0])
message := args[1] message := args[1]
hub.subsLocker.Lock(channel) hub.subsLocker.Lock(channel)
defer hub.subsLocker.UnLock(channel) defer hub.subsLocker.UnLock(channel)
raw, ok := hub.subs.Get(channel) raw, ok := hub.subs.Get(channel)
if !ok { if !ok {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
subscribers, _ := raw.(*list.LinkedList) subscribers, _ := raw.(*list.LinkedList)
subscribers.ForEach(func(i int, c interface{}) bool { subscribers.ForEach(func(i int, c interface{}) bool {
client, _ := c.(redis.Connection) client, _ := c.(redis.Connection)
replyArgs := make([][]byte, 3) replyArgs := make([][]byte, 3)
replyArgs[0] = messageBytes replyArgs[0] = messageBytes
replyArgs[1] = []byte(channel) replyArgs[1] = []byte(channel)
replyArgs[2] = message replyArgs[2] = message
_ = client.Write(reply.MakeMultiBulkReply(replyArgs).ToBytes()) _ = client.Write(reply.MakeMultiBulkReply(replyArgs).ToBytes())
return true return true
}) })
return reply.MakeIntReply(int64(subscribers.Len())) return reply.MakeIntReply(int64(subscribers.Len()))
} }

View File

@@ -1,322 +1,321 @@
package client package client
import ( import (
"bufio" "bufio"
"context" "context"
"errors" "errors"
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/lib/logger" "github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/lib/sync/wait" "github.com/HDT3213/godis/src/lib/sync/wait"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"io" "io"
"net" "net"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
) )
type Client struct { type Client struct {
conn net.Conn conn net.Conn
sendingReqs chan *Request // waiting sending sendingReqs chan *Request // waiting sending
waitingReqs chan *Request // waiting response waitingReqs chan *Request // waiting response
ticker *time.Ticker ticker *time.Ticker
addr string addr string
ctx context.Context ctx context.Context
cancelFunc context.CancelFunc cancelFunc context.CancelFunc
writing *sync.WaitGroup writing *sync.WaitGroup
} }
type Request struct { type Request struct {
id uint64 id uint64
args [][]byte args [][]byte
reply redis.Reply reply redis.Reply
heartbeat bool heartbeat bool
waiting *wait.Wait waiting *wait.Wait
err error err error
} }
const ( const (
chanSize = 256 chanSize = 256
maxWait = 3 * time.Second maxWait = 3 * time.Second
) )
func MakeClient(addr string) (*Client, error) { func MakeClient(addr string) (*Client, error) {
conn, err := net.Dial("tcp", addr) conn, err := net.Dial("tcp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &Client{ return &Client{
addr: addr, addr: addr,
conn: conn, conn: conn,
sendingReqs: make(chan *Request, chanSize), sendingReqs: make(chan *Request, chanSize),
waitingReqs: make(chan *Request, chanSize), waitingReqs: make(chan *Request, chanSize),
ctx: ctx, ctx: ctx,
cancelFunc: cancel, cancelFunc: cancel,
writing: &sync.WaitGroup{}, writing: &sync.WaitGroup{},
}, nil }, nil
} }
func (client *Client) Start() { func (client *Client) Start() {
client.ticker = time.NewTicker(10 * time.Second) client.ticker = time.NewTicker(10 * time.Second)
go client.handleWrite() go client.handleWrite()
go func() { go func() {
err := client.handleRead() err := client.handleRead()
logger.Warn(err) logger.Warn(err)
}() }()
go client.heartbeat() go client.heartbeat()
} }
func (client *Client) Close() { func (client *Client) Close() {
// stop new request // stop new request
close(client.sendingReqs) close(client.sendingReqs)
// wait stop process // wait stop process
client.writing.Wait() client.writing.Wait()
// clean // clean
client.cancelFunc() client.cancelFunc()
_ = client.conn.Close() _ = client.conn.Close()
close(client.waitingReqs) close(client.waitingReqs)
} }
func (client *Client) handleConnectionError(err error) error { func (client *Client) handleConnectionError(err error) error {
err1 := client.conn.Close() err1 := client.conn.Close()
if err1 != nil { if err1 != nil {
if opErr, ok := err1.(*net.OpError); ok { if opErr, ok := err1.(*net.OpError); ok {
if opErr.Err.Error() != "use of closed network connection" { if opErr.Err.Error() != "use of closed network connection" {
return err1 return err1
} }
} else { } else {
return err1 return err1
} }
} }
conn, err1 := net.Dial("tcp", client.addr) conn, err1 := net.Dial("tcp", client.addr)
if err1 != nil { if err1 != nil {
logger.Error(err1) logger.Error(err1)
return err1 return err1
} }
client.conn = conn client.conn = conn
go func() { go func() {
_ = client.handleRead() _ = client.handleRead()
}() }()
return nil return nil
} }
func (client *Client) heartbeat() { func (client *Client) heartbeat() {
loop: loop:
for { for {
select { select {
case <-client.ticker.C: case <-client.ticker.C:
client.sendingReqs <- &Request{ client.sendingReqs <- &Request{
args: [][]byte{[]byte("PING")}, args: [][]byte{[]byte("PING")},
heartbeat: true, heartbeat: true,
} }
case <-client.ctx.Done(): case <-client.ctx.Done():
break loop break loop
} }
} }
} }
func (client *Client) handleWrite() { func (client *Client) handleWrite() {
loop: loop:
for { for {
select { select {
case req := <-client.sendingReqs: case req := <-client.sendingReqs:
client.writing.Add(1) client.writing.Add(1)
client.doRequest(req) client.doRequest(req)
case <-client.ctx.Done(): case <-client.ctx.Done():
break loop break loop
} }
} }
} }
// todo: wait with timeout // todo: wait with timeout
func (client *Client) Send(args [][]byte) redis.Reply { func (client *Client) Send(args [][]byte) redis.Reply {
request := &Request{ request := &Request{
args: args, args: args,
heartbeat: false, heartbeat: false,
waiting: &wait.Wait{}, waiting: &wait.Wait{},
} }
request.waiting.Add(1) request.waiting.Add(1)
client.sendingReqs <- request client.sendingReqs <- request
timeout := request.waiting.WaitWithTimeout(maxWait) timeout := request.waiting.WaitWithTimeout(maxWait)
if timeout { if timeout {
return reply.MakeErrReply("server time out") return reply.MakeErrReply("server time out")
} }
if request.err != nil { if request.err != nil {
return reply.MakeErrReply("request failed") return reply.MakeErrReply("request failed")
} }
return request.reply return request.reply
} }
func (client *Client) doRequest(req *Request) { func (client *Client) doRequest(req *Request) {
bytes := reply.MakeMultiBulkReply(req.args).ToBytes() bytes := reply.MakeMultiBulkReply(req.args).ToBytes()
_, err := client.conn.Write(bytes) _, err := client.conn.Write(bytes)
i := 0 i := 0
for err != nil && i < 3 { for err != nil && i < 3 {
err = client.handleConnectionError(err) err = client.handleConnectionError(err)
if err == nil { if err == nil {
_, err = client.conn.Write(bytes) _, err = client.conn.Write(bytes)
} }
i++ i++
} }
if err == nil { if err == nil {
client.waitingReqs <- req client.waitingReqs <- req
} else { } else {
req.err = err req.err = err
req.waiting.Done() req.waiting.Done()
client.writing.Done() client.writing.Done()
} }
} }
func (client *Client) finishRequest(reply redis.Reply) { func (client *Client) finishRequest(reply redis.Reply) {
request := <-client.waitingReqs request := <-client.waitingReqs
request.reply = reply request.reply = reply
if request.waiting != nil { if request.waiting != nil {
request.waiting.Done() request.waiting.Done()
} }
client.writing.Done() client.writing.Done()
} }
func (client *Client) handleRead() error { func (client *Client) handleRead() error {
reader := bufio.NewReader(client.conn) reader := bufio.NewReader(client.conn)
downloading := false downloading := false
expectedArgsCount := 0 expectedArgsCount := 0
receivedCount := 0 receivedCount := 0
msgType := byte(0) // first char of msg msgType := byte(0) // first char of msg
var args [][]byte var args [][]byte
var fixedLen int64 = 0 var fixedLen int64 = 0
var err error var err error
var msg []byte var msg []byte
for { for {
// read line // read line
if fixedLen == 0 { // read normal line if fixedLen == 0 { // read normal line
msg, err = reader.ReadBytes('\n') msg, err = reader.ReadBytes('\n')
if err != nil { if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF { if err == io.EOF || err == io.ErrUnexpectedEOF {
logger.Info("connection close") logger.Info("connection close")
} else { } else {
logger.Warn(err) logger.Warn(err)
} }
return errors.New("connection closed") return errors.New("connection closed")
} }
if len(msg) == 0 || msg[len(msg)-2] != '\r' { if len(msg) == 0 || msg[len(msg)-2] != '\r' {
return errors.New("protocol error") return errors.New("protocol error")
} }
} else { // read bulk line (binary safe) } else { // read bulk line (binary safe)
msg = make([]byte, fixedLen+2) msg = make([]byte, fixedLen+2)
_, err = io.ReadFull(reader, msg) _, err = io.ReadFull(reader, msg)
if err != nil { if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF { if err == io.EOF || err == io.ErrUnexpectedEOF {
return errors.New("connection closed") return errors.New("connection closed")
} else { } else {
return err return err
} }
} }
if len(msg) == 0 || if len(msg) == 0 ||
msg[len(msg)-2] != '\r' || msg[len(msg)-2] != '\r' ||
msg[len(msg)-1] != '\n' { msg[len(msg)-1] != '\n' {
return errors.New("protocol error") return errors.New("protocol error")
} }
fixedLen = 0 fixedLen = 0
} }
// parse line // parse line
if !downloading { if !downloading {
// receive new response // receive new response
if msg[0] == '*' { // multi bulk response if msg[0] == '*' { // multi bulk response
// bulk multi msg // bulk multi msg
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32) expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
if err != nil { if err != nil {
return errors.New("protocol error: " + err.Error()) return errors.New("protocol error: " + err.Error())
} }
if expectedLine == 0 { if expectedLine == 0 {
client.finishRequest(&reply.EmptyMultiBulkReply{}) client.finishRequest(&reply.EmptyMultiBulkReply{})
} else if expectedLine > 0 { } else if expectedLine > 0 {
msgType = msg[0] msgType = msg[0]
downloading = true downloading = true
expectedArgsCount = int(expectedLine) expectedArgsCount = int(expectedLine)
receivedCount = 0 receivedCount = 0
args = make([][]byte, expectedLine) args = make([][]byte, expectedLine)
} else { } else {
return errors.New("protocol error") return errors.New("protocol error")
} }
} else if msg[0] == '$' { // bulk response } else if msg[0] == '$' { // bulk response
fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64) fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
if err != nil { if err != nil {
return err return err
} }
if fixedLen == -1 { // null bulk if fixedLen == -1 { // null bulk
client.finishRequest(&reply.NullBulkReply{}) client.finishRequest(&reply.NullBulkReply{})
fixedLen = 0 fixedLen = 0
} else if fixedLen > 0 { } else if fixedLen > 0 {
msgType = msg[0] msgType = msg[0]
downloading = true downloading = true
expectedArgsCount = 1 expectedArgsCount = 1
receivedCount = 0 receivedCount = 0
args = make([][]byte, 1) args = make([][]byte, 1)
} else { } else {
return errors.New("protocol error") return errors.New("protocol error")
} }
} else { // single line response } else { // single line response
str := strings.TrimSuffix(string(msg), "\n") str := strings.TrimSuffix(string(msg), "\n")
str = strings.TrimSuffix(str, "\r") str = strings.TrimSuffix(str, "\r")
var result redis.Reply var result redis.Reply
switch msg[0] { switch msg[0] {
case '+': case '+':
result = reply.MakeStatusReply(str[1:]) result = reply.MakeStatusReply(str[1:])
case '-': case '-':
result = reply.MakeErrReply(str[1:]) result = reply.MakeErrReply(str[1:])
case ':': case ':':
val, err := strconv.ParseInt(str[1:], 10, 64) val, err := strconv.ParseInt(str[1:], 10, 64)
if err != nil { if err != nil {
return errors.New("protocol error") return errors.New("protocol error")
} }
result = reply.MakeIntReply(val) result = reply.MakeIntReply(val)
} }
client.finishRequest(result) client.finishRequest(result)
} }
} else { } else {
// receive following part of a request // receive following part of a request
line := msg[0 : len(msg)-2] line := msg[0 : len(msg)-2]
if line[0] == '$' { if line[0] == '$' {
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64) fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
if err != nil { if err != nil {
return err return err
} }
if fixedLen <= 0 { // null bulk in multi bulks if fixedLen <= 0 { // null bulk in multi bulks
args[receivedCount] = []byte{} args[receivedCount] = []byte{}
receivedCount++ receivedCount++
fixedLen = 0 fixedLen = 0
} }
} else { } else {
args[receivedCount] = line args[receivedCount] = line
receivedCount++ receivedCount++
} }
// if sending finished // if sending finished
if receivedCount == expectedArgsCount { if receivedCount == expectedArgsCount {
downloading = false // finish downloading progress downloading = false // finish downloading progress
if msgType == '*' { if msgType == '*' {
reply := reply.MakeMultiBulkReply(args) reply := reply.MakeMultiBulkReply(args)
client.finishRequest(reply) client.finishRequest(reply)
} else if msgType == '$' { } else if msgType == '$' {
reply := reply.MakeBulkReply(args[0]) reply := reply.MakeBulkReply(args[0])
client.finishRequest(reply) client.finishRequest(reply)
} }
// finish reply
// finish reply expectedArgsCount = 0
expectedArgsCount = 0 receivedCount = 0
receivedCount = 0 args = nil
args = nil msgType = byte(0)
msgType = byte(0) }
} }
} }
}
} }

View File

@@ -1,104 +1,104 @@
package client package client
import ( import (
"github.com/HDT3213/godis/src/lib/logger" "github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"testing" "testing"
) )
func TestClient(t *testing.T) { func TestClient(t *testing.T) {
logger.Setup(&logger.Settings{ logger.Setup(&logger.Settings{
Path: "logs", Path: "logs",
Name: "godis", Name: "godis",
Ext: ".log", Ext: ".log",
TimeFormat: "2006-01-02", TimeFormat: "2006-01-02",
}) })
client, err := MakeClient("localhost:6379") client, err := MakeClient("localhost:6379")
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
client.Start() client.Start()
result := client.Send([][]byte{ result := client.Send([][]byte{
[]byte("PING"), []byte("PING"),
}) })
if statusRet, ok := result.(*reply.StatusReply); ok { if statusRet, ok := result.(*reply.StatusReply); ok {
if statusRet.Status != "PONG" { if statusRet.Status != "PONG" {
t.Error("`ping` failed, result: " + statusRet.Status) t.Error("`ping` failed, result: " + statusRet.Status)
} }
} }
result = client.Send([][]byte{ result = client.Send([][]byte{
[]byte("SET"), []byte("SET"),
[]byte("a"), []byte("a"),
[]byte("a"), []byte("a"),
}) })
if statusRet, ok := result.(*reply.StatusReply); ok { if statusRet, ok := result.(*reply.StatusReply); ok {
if statusRet.Status != "OK" { if statusRet.Status != "OK" {
t.Error("`set` failed, result: " + statusRet.Status) t.Error("`set` failed, result: " + statusRet.Status)
} }
} }
result = client.Send([][]byte{ result = client.Send([][]byte{
[]byte("GET"), []byte("GET"),
[]byte("a"), []byte("a"),
}) })
if bulkRet, ok := result.(*reply.BulkReply); ok { if bulkRet, ok := result.(*reply.BulkReply); ok {
if string(bulkRet.Arg) != "a" { if string(bulkRet.Arg) != "a" {
t.Error("`get` failed, result: " + string(bulkRet.Arg)) t.Error("`get` failed, result: " + string(bulkRet.Arg))
} }
} }
result = client.Send([][]byte{ result = client.Send([][]byte{
[]byte("DEL"), []byte("DEL"),
[]byte("a"), []byte("a"),
}) })
if intRet, ok := result.(*reply.IntReply); ok { if intRet, ok := result.(*reply.IntReply); ok {
if intRet.Code != 1 { if intRet.Code != 1 {
t.Error("`del` failed, result: " + string(intRet.Code)) t.Error("`del` failed, result: " + string(intRet.Code))
} }
} }
result = client.Send([][]byte{ result = client.Send([][]byte{
[]byte("GET"), []byte("GET"),
[]byte("a"), []byte("a"),
}) })
if _, ok := result.(*reply.NullBulkReply); !ok { if _, ok := result.(*reply.NullBulkReply); !ok {
t.Error("`get` failed, result: " + string(result.ToBytes())) t.Error("`get` failed, result: " + string(result.ToBytes()))
} }
result = client.Send([][]byte{ result = client.Send([][]byte{
[]byte("DEL"), []byte("DEL"),
[]byte("arr"), []byte("arr"),
}) })
result = client.Send([][]byte{ result = client.Send([][]byte{
[]byte("RPUSH"), []byte("RPUSH"),
[]byte("arr"), []byte("arr"),
[]byte("1"), []byte("1"),
[]byte("2"), []byte("2"),
[]byte("c"), []byte("c"),
}) })
if intRet, ok := result.(*reply.IntReply); ok { if intRet, ok := result.(*reply.IntReply); ok {
if intRet.Code != 3 { if intRet.Code != 3 {
t.Error("`rpush` failed, result: " + string(intRet.Code)) t.Error("`rpush` failed, result: " + string(intRet.Code))
} }
} }
result = client.Send([][]byte{ result = client.Send([][]byte{
[]byte("LRANGE"), []byte("LRANGE"),
[]byte("arr"), []byte("arr"),
[]byte("0"), []byte("0"),
[]byte("-1"), []byte("-1"),
}) })
if multiBulkRet, ok := result.(*reply.MultiBulkReply); ok { if multiBulkRet, ok := result.(*reply.MultiBulkReply); ok {
if len(multiBulkRet.Args) != 3 || if len(multiBulkRet.Args) != 3 ||
string(multiBulkRet.Args[0]) != "1" || string(multiBulkRet.Args[0]) != "1" ||
string(multiBulkRet.Args[1]) != "2" || string(multiBulkRet.Args[1]) != "2" ||
string(multiBulkRet.Args[2]) != "c" { string(multiBulkRet.Args[2]) != "c" {
t.Error("`lrange` failed, result: " + string(multiBulkRet.ToBytes())) t.Error("`lrange` failed, result: " + string(multiBulkRet.ToBytes()))
} }
} }
client.Close() client.Close()
} }

View File

@@ -1,42 +1,42 @@
package reply package reply
type PongReply struct {} type PongReply struct{}
var PongBytes = []byte("+PONG\r\n") var PongBytes = []byte("+PONG\r\n")
func (r *PongReply)ToBytes()[]byte { func (r *PongReply) ToBytes() []byte {
return PongBytes return PongBytes
} }
type OkReply struct {} type OkReply struct{}
var okBytes = []byte("+OK\r\n") var okBytes = []byte("+OK\r\n")
func (r *OkReply)ToBytes()[]byte { func (r *OkReply) ToBytes() []byte {
return okBytes return okBytes
} }
var nullBulkBytes = []byte("$-1\r\n") var nullBulkBytes = []byte("$-1\r\n")
type NullBulkReply struct {} type NullBulkReply struct{}
func (r *NullBulkReply)ToBytes()[]byte { func (r *NullBulkReply) ToBytes() []byte {
return nullBulkBytes return nullBulkBytes
} }
var emptyMultiBulkBytes = []byte("*0\r\n") var emptyMultiBulkBytes = []byte("*0\r\n")
type EmptyMultiBulkReply struct {} type EmptyMultiBulkReply struct{}
func (r *EmptyMultiBulkReply)ToBytes()[]byte { func (r *EmptyMultiBulkReply) ToBytes() []byte {
return emptyMultiBulkBytes return emptyMultiBulkBytes
} }
// reply nothing, for commands like subscribe // reply nothing, for commands like subscribe
type NoReply struct {} type NoReply struct{}
var NoBytes = []byte("") var NoBytes = []byte("")
func (r *NoReply)ToBytes()[]byte { func (r *NoReply) ToBytes() []byte {
return NoBytes return NoBytes
} }

View File

@@ -1,67 +1,67 @@
package reply package reply
// UnknownErr // UnknownErr
type UnknownErrReply struct {} type UnknownErrReply struct{}
var unknownErrBytes = []byte("-Err unknown\r\n") var unknownErrBytes = []byte("-Err unknown\r\n")
func (r *UnknownErrReply)ToBytes()[]byte { func (r *UnknownErrReply) ToBytes() []byte {
return unknownErrBytes return unknownErrBytes
} }
func (r *UnknownErrReply) Error()string { func (r *UnknownErrReply) Error() string {
return "Err unknown" return "Err unknown"
} }
// ArgNumErr // ArgNumErr
type ArgNumErrReply struct { type ArgNumErrReply struct {
Cmd string Cmd string
} }
func (r *ArgNumErrReply)ToBytes()[]byte { func (r *ArgNumErrReply) ToBytes() []byte {
return []byte("-ERR wrong number of arguments for '" + r.Cmd + "' command\r\n") return []byte("-ERR wrong number of arguments for '" + r.Cmd + "' command\r\n")
} }
func (r *ArgNumErrReply) Error()string { func (r *ArgNumErrReply) Error() string {
return "ERR wrong number of arguments for '" + r.Cmd + "' command" return "ERR wrong number of arguments for '" + r.Cmd + "' command"
} }
// SyntaxErr // SyntaxErr
type SyntaxErrReply struct {} type SyntaxErrReply struct{}
var syntaxErrBytes = []byte("-Err syntax error\r\n") var syntaxErrBytes = []byte("-Err syntax error\r\n")
func (r *SyntaxErrReply)ToBytes()[]byte { func (r *SyntaxErrReply) ToBytes() []byte {
return syntaxErrBytes return syntaxErrBytes
} }
func (r *SyntaxErrReply)Error()string { func (r *SyntaxErrReply) Error() string {
return "Err syntax error" return "Err syntax error"
} }
// WrongTypeErr // WrongTypeErr
type WrongTypeErrReply struct {} type WrongTypeErrReply struct{}
var wrongTypeErrBytes = []byte("-WRONGTYPE Operation against a key holding the wrong kind of value\r\n") var wrongTypeErrBytes = []byte("-WRONGTYPE Operation against a key holding the wrong kind of value\r\n")
func (r *WrongTypeErrReply)ToBytes()[]byte { func (r *WrongTypeErrReply) ToBytes() []byte {
return wrongTypeErrBytes return wrongTypeErrBytes
} }
func (r *WrongTypeErrReply)Error()string { func (r *WrongTypeErrReply) Error() string {
return "WRONGTYPE Operation against a key holding the wrong kind of value" return "WRONGTYPE Operation against a key holding the wrong kind of value"
} }
// ProtocolErr // ProtocolErr
type ProtocolErrReply struct { type ProtocolErrReply struct {
Msg string Msg string
} }
func (r *ProtocolErrReply)ToBytes()[]byte { func (r *ProtocolErrReply) ToBytes() []byte {
return []byte("-ERR Protocol error: '" + r.Msg + "'\r\n") return []byte("-ERR Protocol error: '" + r.Msg + "'\r\n")
} }
func (r *ProtocolErrReply) Error()string { func (r *ProtocolErrReply) Error() string {
return "ERR Protocol error: '" + r.Msg return "ERR Protocol error: '" + r.Msg
} }

View File

@@ -1,141 +1,140 @@
package reply package reply
import ( import (
"bytes" "bytes"
"github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/interface/redis"
"strconv" "strconv"
) )
var ( var (
nullBulkReplyBytes = []byte("$-1") nullBulkReplyBytes = []byte("$-1")
CRLF = "\r\n" CRLF = "\r\n"
) )
/* ---- Bulk Reply ---- */ /* ---- Bulk Reply ---- */
type BulkReply struct { type BulkReply struct {
Arg []byte Arg []byte
} }
func MakeBulkReply(arg []byte) *BulkReply { func MakeBulkReply(arg []byte) *BulkReply {
return &BulkReply{ return &BulkReply{
Arg: arg, Arg: arg,
} }
} }
func (r *BulkReply) ToBytes() []byte { func (r *BulkReply) ToBytes() []byte {
if len(r.Arg) == 0 { if len(r.Arg) == 0 {
return nullBulkReplyBytes return nullBulkReplyBytes
} }
return []byte("$" + strconv.Itoa(len(r.Arg)) + CRLF + string(r.Arg) + CRLF) return []byte("$" + strconv.Itoa(len(r.Arg)) + CRLF + string(r.Arg) + CRLF)
} }
/* ---- Multi Bulk Reply ---- */ /* ---- Multi Bulk Reply ---- */
type MultiBulkReply struct { type MultiBulkReply struct {
Args [][]byte Args [][]byte
} }
func MakeMultiBulkReply(args [][]byte) *MultiBulkReply { func MakeMultiBulkReply(args [][]byte) *MultiBulkReply {
return &MultiBulkReply{ return &MultiBulkReply{
Args: args, Args: args,
} }
} }
func (r *MultiBulkReply) ToBytes() []byte { func (r *MultiBulkReply) ToBytes() []byte {
argLen := len(r.Args) argLen := len(r.Args)
var buf bytes.Buffer var buf bytes.Buffer
buf.WriteString("*" + strconv.Itoa(argLen) + CRLF) buf.WriteString("*" + strconv.Itoa(argLen) + CRLF)
for _, arg := range r.Args { for _, arg := range r.Args {
if arg == nil { if arg == nil {
buf.WriteString("$-1" + CRLF) buf.WriteString("$-1" + CRLF)
} else { } else {
buf.WriteString("$" + strconv.Itoa(len(arg)) + CRLF + string(arg) + CRLF) buf.WriteString("$" + strconv.Itoa(len(arg)) + CRLF + string(arg) + CRLF)
} }
} }
return buf.Bytes() return buf.Bytes()
} }
/* ---- Multi Raw Reply ---- */ /* ---- Multi Raw Reply ---- */
type MultiRawReply struct { type MultiRawReply struct {
Args [][]byte Args [][]byte
} }
func MakeMultiRawReply(args [][]byte) *MultiRawReply { func MakeMultiRawReply(args [][]byte) *MultiRawReply {
return &MultiRawReply{ return &MultiRawReply{
Args: args, Args: args,
} }
} }
func (r *MultiRawReply) ToBytes() []byte { func (r *MultiRawReply) ToBytes() []byte {
argLen := len(r.Args) argLen := len(r.Args)
var buf bytes.Buffer var buf bytes.Buffer
buf.WriteString("*" + strconv.Itoa(argLen) + CRLF) buf.WriteString("*" + strconv.Itoa(argLen) + CRLF)
for _, arg := range r.Args { for _, arg := range r.Args {
buf.Write(arg) buf.Write(arg)
} }
return buf.Bytes() return buf.Bytes()
} }
/* ---- Status Reply ---- */ /* ---- Status Reply ---- */
type StatusReply struct { type StatusReply struct {
Status string Status string
} }
func MakeStatusReply(status string) *StatusReply { func MakeStatusReply(status string) *StatusReply {
return &StatusReply{ return &StatusReply{
Status: status, Status: status,
} }
} }
func (r *StatusReply) ToBytes() []byte { func (r *StatusReply) ToBytes() []byte {
return []byte("+" + r.Status + "\r\n") return []byte("+" + r.Status + "\r\n")
} }
/* ---- Int Reply ---- */ /* ---- Int Reply ---- */
type IntReply struct { type IntReply struct {
Code int64 Code int64
} }
func MakeIntReply(code int64) *IntReply { func MakeIntReply(code int64) *IntReply {
return &IntReply{ return &IntReply{
Code: code, Code: code,
} }
} }
func (r *IntReply) ToBytes() []byte { func (r *IntReply) ToBytes() []byte {
return []byte(":" + strconv.FormatInt(r.Code, 10) + CRLF) return []byte(":" + strconv.FormatInt(r.Code, 10) + CRLF)
} }
/* ---- Error Reply ---- */ /* ---- Error Reply ---- */
type ErrorReply interface { type ErrorReply interface {
Error() string Error() string
ToBytes() []byte ToBytes() []byte
} }
type StandardErrReply struct { type StandardErrReply struct {
Status string Status string
} }
func MakeErrReply(status string) *StandardErrReply { func MakeErrReply(status string) *StandardErrReply {
return &StandardErrReply{ return &StandardErrReply{
Status: status, Status: status,
} }
} }
func IsErrorReply(reply redis.Reply) bool { func IsErrorReply(reply redis.Reply) bool {
return reply.ToBytes()[0] == '-' return reply.ToBytes()[0] == '-'
} }
func (r *StandardErrReply) ToBytes() []byte { func (r *StandardErrReply) ToBytes() []byte {
return []byte("-" + r.Status + "\r\n") return []byte("-" + r.Status + "\r\n")
} }
func (r *StandardErrReply) Error() string { func (r *StandardErrReply) Error() string {
return r.Status return r.Status
} }

View File

@@ -1,95 +1,95 @@
package server package server
import ( import (
"github.com/HDT3213/godis/src/lib/sync/atomic" "github.com/HDT3213/godis/src/lib/sync/atomic"
"github.com/HDT3213/godis/src/lib/sync/wait" "github.com/HDT3213/godis/src/lib/sync/wait"
"net" "net"
"sync" "sync"
"time" "time"
) )
// abstract of active client // abstract of active client
type Client struct { type Client struct {
conn net.Conn conn net.Conn
// waiting util reply finished // waiting util reply finished
waitingReply wait.Wait waitingReply wait.Wait
// is sending request in progress // is sending request in progress
uploading atomic.AtomicBool uploading atomic.AtomicBool
// multi bulk msg lineCount - 1(first line) // multi bulk msg lineCount - 1(first line)
expectedArgsCount uint32 expectedArgsCount uint32
// sent line count, exclude first line // sent line count, exclude first line
receivedCount uint32 receivedCount uint32
// sent lines, exclude first line // sent lines, exclude first line
args [][]byte args [][]byte
// lock while server sending response // lock while server sending response
mu sync.Mutex mu sync.Mutex
// subscribing channels // subscribing channels
subs map[string]bool subs map[string]bool
} }
func (c *Client)Close()error { func (c *Client) Close() error {
c.waitingReply.WaitWithTimeout(10 * time.Second) c.waitingReply.WaitWithTimeout(10 * time.Second)
_ = c.conn.Close() _ = c.conn.Close()
return nil return nil
} }
func MakeClient(conn net.Conn) *Client { func MakeClient(conn net.Conn) *Client {
return &Client{ return &Client{
conn: conn, conn: conn,
} }
} }
func (c *Client)Write(b []byte)error { func (c *Client) Write(b []byte) error {
if b == nil || len(b) == 0 { if b == nil || len(b) == 0 {
return nil return nil
} }
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
_, err := c.conn.Write(b) _, err := c.conn.Write(b)
return err return err
} }
func (c *Client)SubsChannel(channel string) { func (c *Client) SubsChannel(channel string) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.subs == nil { if c.subs == nil {
c.subs = make(map[string]bool) c.subs = make(map[string]bool)
} }
c.subs[channel] = true c.subs[channel] = true
} }
func (c *Client)UnSubsChannel(channel string) { func (c *Client) UnSubsChannel(channel string) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if c.subs == nil { if c.subs == nil {
return return
} }
delete(c.subs, channel) delete(c.subs, channel)
} }
func (c *Client)SubsCount()int { func (c *Client) SubsCount() int {
if c.subs == nil { if c.subs == nil {
return 0 return 0
} }
return len(c.subs) return len(c.subs)
} }
func (c *Client)GetChannels()[]string { func (c *Client) GetChannels() []string {
if c.subs == nil { if c.subs == nil {
return make([]string, 0) return make([]string, 0)
} }
channels := make([]string, len(c.subs)) channels := make([]string, len(c.subs))
i := 0 i := 0
for channel := range c.subs { for channel := range c.subs {
channels[i] = channel channels[i] = channel
i++ i++
} }
return channels return channels
} }

View File

@@ -5,192 +5,192 @@ package server
*/ */
import ( import (
"bufio" "bufio"
"context" "context"
"github.com/HDT3213/godis/src/cluster" "github.com/HDT3213/godis/src/cluster"
"github.com/HDT3213/godis/src/config" "github.com/HDT3213/godis/src/config"
DBImpl "github.com/HDT3213/godis/src/db" DBImpl "github.com/HDT3213/godis/src/db"
"github.com/HDT3213/godis/src/interface/db" "github.com/HDT3213/godis/src/interface/db"
"github.com/HDT3213/godis/src/lib/logger" "github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/lib/sync/atomic" "github.com/HDT3213/godis/src/lib/sync/atomic"
"github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/redis/reply"
"io" "io"
"net" "net"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
) )
var ( var (
UnknownErrReplyBytes = []byte("-ERR unknown\r\n") UnknownErrReplyBytes = []byte("-ERR unknown\r\n")
) )
type Handler struct { type Handler struct {
activeConn sync.Map // *client -> placeholder activeConn sync.Map // *client -> placeholder
db db.DB db db.DB
closing atomic.AtomicBool // refusing new client and new request closing atomic.AtomicBool // refusing new client and new request
} }
func MakeHandler() *Handler { func MakeHandler() *Handler {
var db db.DB var db db.DB
if config.Properties.Peers != nil && if config.Properties.Peers != nil &&
len(config.Properties.Peers) > 0 { len(config.Properties.Peers) > 0 {
db = cluster.MakeCluster() db = cluster.MakeCluster()
} else { } else {
db = DBImpl.MakeDB() db = DBImpl.MakeDB()
} }
return &Handler{ return &Handler{
db: db, db: db,
} }
} }
func (h *Handler) closeClient(client *Client) { func (h *Handler) closeClient(client *Client) {
_ = client.Close() _ = client.Close()
h.db.AfterClientClose(client) h.db.AfterClientClose(client)
h.activeConn.Delete(client) h.activeConn.Delete(client)
} }
func (h *Handler) Handle(ctx context.Context, conn net.Conn) { func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
if h.closing.Get() { if h.closing.Get() {
// closing handler refuse new connection // closing handler refuse new connection
_ = conn.Close() _ = conn.Close()
} }
client := MakeClient(conn) client := MakeClient(conn)
h.activeConn.Store(client, 1) h.activeConn.Store(client, 1)
reader := bufio.NewReader(conn) reader := bufio.NewReader(conn)
var fixedLen int64 = 0 var fixedLen int64 = 0
var err error var err error
var msg []byte var msg []byte
for { for {
if fixedLen == 0 { if fixedLen == 0 {
msg, err = reader.ReadBytes('\n') msg, err = reader.ReadBytes('\n')
if err != nil { if err != nil {
if err == io.EOF || if err == io.EOF ||
err == io.ErrUnexpectedEOF || err == io.ErrUnexpectedEOF ||
strings.Contains(err.Error(), "use of closed network connection") { strings.Contains(err.Error(), "use of closed network connection") {
logger.Info("connection close") logger.Info("connection close")
} else { } else {
logger.Warn(err) logger.Warn(err)
} }
// after client close // after client close
h.closeClient(client) h.closeClient(client)
return // io error, disconnect with client return // io error, disconnect with client
} }
if len(msg) == 0 || msg[len(msg)-2] != '\r' { if len(msg) == 0 || msg[len(msg)-2] != '\r' {
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"} errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
_, _ = client.conn.Write(errReply.ToBytes()) _, _ = client.conn.Write(errReply.ToBytes())
} }
} else { } else {
msg = make([]byte, fixedLen+2) msg = make([]byte, fixedLen+2)
_, err = io.ReadFull(reader, msg) _, err = io.ReadFull(reader, msg)
if err != nil { if err != nil {
if err == io.EOF || if err == io.EOF ||
err == io.ErrUnexpectedEOF || err == io.ErrUnexpectedEOF ||
strings.Contains(err.Error(), "use of closed network connection") { strings.Contains(err.Error(), "use of closed network connection") {
logger.Info("connection close") logger.Info("connection close")
} else { } else {
logger.Warn(err) logger.Warn(err)
} }
// after client close // after client close
h.closeClient(client) h.closeClient(client)
return // io error, disconnect with client return // io error, disconnect with client
} }
if len(msg) == 0 || if len(msg) == 0 ||
msg[len(msg)-2] != '\r' || msg[len(msg)-2] != '\r' ||
msg[len(msg)-1] != '\n' { msg[len(msg)-1] != '\n' {
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"} errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
_, _ = client.conn.Write(errReply.ToBytes()) _, _ = client.conn.Write(errReply.ToBytes())
} }
fixedLen = 0 fixedLen = 0
} }
if !client.uploading.Get() { if !client.uploading.Get() {
// new request // new request
if msg[0] == '*' { if msg[0] == '*' {
// bulk multi msg // bulk multi msg
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32) expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
if err != nil { if err != nil {
_, _ = client.conn.Write(UnknownErrReplyBytes) _, _ = client.conn.Write(UnknownErrReplyBytes)
continue continue
} }
client.waitingReply.Add(1) client.waitingReply.Add(1)
client.uploading.Set(true) client.uploading.Set(true)
client.expectedArgsCount = uint32(expectedLine) client.expectedArgsCount = uint32(expectedLine)
client.receivedCount = 0 client.receivedCount = 0
client.args = make([][]byte, expectedLine) client.args = make([][]byte, expectedLine)
} else { } else {
// text protocol // text protocol
// remove \r or \n or \r\n in the end of line // remove \r or \n or \r\n in the end of line
str := strings.TrimSuffix(string(msg), "\n") str := strings.TrimSuffix(string(msg), "\n")
str = strings.TrimSuffix(str, "\r") str = strings.TrimSuffix(str, "\r")
strs := strings.Split(str, " ") strs := strings.Split(str, " ")
args := make([][]byte, len(strs)) args := make([][]byte, len(strs))
for i, s := range strs { for i, s := range strs {
args[i] = []byte(s) args[i] = []byte(s)
} }
// send reply // send reply
result := h.db.Exec(client, args) result := h.db.Exec(client, args)
if result != nil { if result != nil {
_ = client.Write(result.ToBytes()) _ = client.Write(result.ToBytes())
} else { } else {
_ = client.Write(UnknownErrReplyBytes) _ = client.Write(UnknownErrReplyBytes)
} }
} }
} else { } else {
// receive following part of a request // receive following part of a request
line := msg[0 : len(msg)-2] line := msg[0 : len(msg)-2]
if line[0] == '$' { if line[0] == '$' {
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64) fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
if err != nil { if err != nil {
errReply := &reply.ProtocolErrReply{Msg: err.Error()} errReply := &reply.ProtocolErrReply{Msg: err.Error()}
_, _ = client.conn.Write(errReply.ToBytes()) _, _ = client.conn.Write(errReply.ToBytes())
} }
if fixedLen <= 0 { if fixedLen <= 0 {
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"} errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
_, _ = client.conn.Write(errReply.ToBytes()) _, _ = client.conn.Write(errReply.ToBytes())
} }
} else { } else {
client.args[client.receivedCount] = line client.args[client.receivedCount] = line
client.receivedCount++ client.receivedCount++
} }
// if sending finished // if sending finished
if client.receivedCount == client.expectedArgsCount { if client.receivedCount == client.expectedArgsCount {
client.uploading.Set(false) // finish sending progress client.uploading.Set(false) // finish sending progress
// send reply // send reply
result := h.db.Exec(client, client.args) result := h.db.Exec(client, client.args)
if result != nil { if result != nil {
_ = client.Write(result.ToBytes()) _ = client.Write(result.ToBytes())
} else { } else {
_ = client.Write(UnknownErrReplyBytes) _ = client.Write(UnknownErrReplyBytes)
} }
// finish reply // finish reply
client.expectedArgsCount = 0 client.expectedArgsCount = 0
client.receivedCount = 0 client.receivedCount = 0
client.args = nil client.args = nil
client.waitingReply.Done() client.waitingReply.Done()
} }
} }
} }
} }
func (h *Handler) Close() error { func (h *Handler) Close() error {
logger.Info("handler shuting down...") logger.Info("handler shuting down...")
h.closing.Set(true) h.closing.Set(true)
// TODO: concurrent wait // TODO: concurrent wait
h.activeConn.Range(func(key interface{}, val interface{}) bool { h.activeConn.Range(func(key interface{}, val interface{}) bool {
client := key.(*Client) client := key.(*Client)
_ = client.Close() _ = client.Close()
return true return true
}) })
h.db.Close() h.db.Close()
return nil return nil
} }

View File

@@ -5,79 +5,78 @@ package tcp
*/ */
import ( import (
"net" "bufio"
"context" "context"
"bufio" "github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/lib/logger" "github.com/HDT3213/godis/src/lib/sync/atomic"
"sync" "github.com/HDT3213/godis/src/lib/sync/wait"
"io" "io"
"github.com/HDT3213/godis/src/lib/sync/atomic" "net"
"time" "sync"
"github.com/HDT3213/godis/src/lib/sync/wait" "time"
) )
type EchoHandler struct { type EchoHandler struct {
activeConn sync.Map activeConn sync.Map
closing atomic.AtomicBool closing atomic.AtomicBool
} }
func MakeEchoHandler()(*EchoHandler) { func MakeEchoHandler() *EchoHandler {
return &EchoHandler{ return &EchoHandler{}
}
} }
type Client struct { type Client struct {
Conn net.Conn Conn net.Conn
Waiting wait.Wait Waiting wait.Wait
} }
func (c *Client)Close()error { func (c *Client) Close() error {
c.Waiting.WaitWithTimeout(10 * time.Second) c.Waiting.WaitWithTimeout(10 * time.Second)
c.Conn.Close() c.Conn.Close()
return nil return nil
} }
func (h *EchoHandler)Handle(ctx context.Context, conn net.Conn) { func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
if h.closing.Get() { if h.closing.Get() {
// closing handler refuse new connection // closing handler refuse new connection
conn.Close() conn.Close()
} }
client := &Client { client := &Client{
Conn: conn, Conn: conn,
} }
h.activeConn.Store(client, 1) h.activeConn.Store(client, 1)
reader := bufio.NewReader(conn) reader := bufio.NewReader(conn)
for { for {
// may occurs: client EOF, client timeout, server early close // may occurs: client EOF, client timeout, server early close
msg, err := reader.ReadString('\n') msg, err := reader.ReadString('\n')
if err != nil { if err != nil {
if err == io.EOF { if err == io.EOF {
logger.Info("connection close") logger.Info("connection close")
h.activeConn.Delete(client) h.activeConn.Delete(client)
} else { } else {
logger.Warn(err) logger.Warn(err)
} }
return return
} }
client.Waiting.Add(1) client.Waiting.Add(1)
//logger.Info("sleeping") //logger.Info("sleeping")
//time.Sleep(10 * time.Second) //time.Sleep(10 * time.Second)
b := []byte(msg) b := []byte(msg)
conn.Write(b) conn.Write(b)
client.Waiting.Done() client.Waiting.Done()
} }
} }
func (h *EchoHandler)Close()error { func (h *EchoHandler) Close() error {
logger.Info("handler shuting down...") logger.Info("handler shuting down...")
h.closing.Set(true) h.closing.Set(true)
// TODO: concurrent wait // TODO: concurrent wait
h.activeConn.Range(func(key interface{}, val interface{})bool { h.activeConn.Range(func(key interface{}, val interface{}) bool {
client := key.(*Client) client := key.(*Client)
client.Close() client.Close()
return true return true
}) })
return nil return nil
} }

View File

@@ -5,75 +5,74 @@ package tcp
*/ */
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/HDT3213/godis/src/interface/tcp" "github.com/HDT3213/godis/src/interface/tcp"
"github.com/HDT3213/godis/src/lib/logger" "github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/lib/sync/atomic" "github.com/HDT3213/godis/src/lib/sync/atomic"
"net" "net"
"os" "os"
"os/signal" "os/signal"
"sync" "sync"
"syscall" "syscall"
"time" "time"
) )
type Config struct { type Config struct {
Address string `yaml:"address"` Address string `yaml:"address"`
MaxConnect uint32 `yaml:"max-connect"` MaxConnect uint32 `yaml:"max-connect"`
Timeout time.Duration `yaml:"timeout"` Timeout time.Duration `yaml:"timeout"`
} }
func ListenAndServe(cfg *Config, handler tcp.Handler) { func ListenAndServe(cfg *Config, handler tcp.Handler) {
listener, err := net.Listen("tcp", cfg.Address) listener, err := net.Listen("tcp", cfg.Address)
if err != nil { if err != nil {
logger.Fatal(fmt.Sprintf("listen err: %v", err)) logger.Fatal(fmt.Sprintf("listen err: %v", err))
} }
// listen signal // listen signal
var closing atomic.AtomicBool var closing atomic.AtomicBool
sigCh := make(chan os.Signal, 1) sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT) signal.Notify(sigCh, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT)
go func() { go func() {
sig := <-sigCh sig := <-sigCh
switch sig { switch sig {
case syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT: case syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT:
logger.Info("shuting down...") logger.Info("shuting down...")
closing.Set(true) closing.Set(true)
_ = listener.Close() // listener.Accept() will return err immediately _ = listener.Close() // listener.Accept() will return err immediately
_ = handler.Close() // close connections _ = handler.Close() // close connections
} }
}() }()
// listen port
// listen port logger.Info(fmt.Sprintf("bind: %s, start listening...", cfg.Address))
logger.Info(fmt.Sprintf("bind: %s, start listening...", cfg.Address)) defer func() {
defer func() { // close during unexpected error
// close during unexpected error _ = listener.Close()
_ = listener.Close() _ = handler.Close()
_ = handler.Close() }()
}() ctx, _ := context.WithCancel(context.Background())
ctx, _ := context.WithCancel(context.Background()) var waitDone sync.WaitGroup
var waitDone sync.WaitGroup for {
for { conn, err := listener.Accept()
conn, err := listener.Accept() if err != nil {
if err != nil { if closing.Get() {
if closing.Get() { logger.Info("waiting disconnect...")
logger.Info("waiting disconnect...") waitDone.Wait()
waitDone.Wait() return // handler will be closed by defer
return // handler will be closed by defer }
} logger.Error(fmt.Sprintf("accept err: %v", err))
logger.Error(fmt.Sprintf("accept err: %v", err)) continue
continue }
} // handle
// handle logger.Info("accept link")
logger.Info("accept link") waitDone.Add(1)
waitDone.Add(1) go func() {
go func() { defer func() {
defer func() { waitDone.Done()
waitDone.Done() }()
}() handler.Handle(ctx, conn)
handler.Handle(ctx, conn) }()
}() }
}
} }