mirror of
https://github.com/HDT3213/godis.git
synced 2025-10-05 08:46:56 +08:00
reformat code
This commit is contained in:
17
README.md
17
README.md
@@ -2,11 +2,13 @@
|
||||
|
||||
[中文版](https://github.com/HDT3213/godis/blob/master/README_CN.md)
|
||||
|
||||
`Godis` is a simple implementation of Redis Server, which intents to provide an example of writing a high concurrent middleware using golang.
|
||||
`Godis` is a simple implementation of Redis Server, which intents to provide an example of writing a high concurrent
|
||||
middleware using golang.
|
||||
|
||||
Please be advised, NEVER think about using this in production environment.
|
||||
|
||||
This repository implemented most features of redis, including 5 data structures, ttl, publish/subscribe, AOF persistence and server side cluster mode.
|
||||
This repository implemented most features of redis, including 5 data structures, ttl, publish/subscribe, AOF persistence
|
||||
and server side cluster mode.
|
||||
|
||||
If you could read Chinese, you can find more details in [My Blog](https://www.cnblogs.com/Finley/category/1598973.html).
|
||||
|
||||
@@ -22,7 +24,7 @@ You could use redis-cli or other redis client to connect godis server, which lis
|
||||
|
||||
The program will try to read config file path from environment variable `CONFIG`.
|
||||
|
||||
If environment variable is not set, then the program try to read `redis.conf` in the working directory.
|
||||
If environment variable is not set, then the program try to read `redis.conf` in the working directory.
|
||||
|
||||
If there is no such file, then the program will run with default config.
|
||||
|
||||
@@ -35,8 +37,7 @@ peers localhost:7379,localhost:7389 // other node in cluster
|
||||
self localhost:6399 // self address
|
||||
```
|
||||
|
||||
We provide node1.conf and node2.conf for demonstration.
|
||||
use following command line to start a two-node-cluster:
|
||||
We provide node1.conf and node2.conf for demonstration. use following command line to start a two-node-cluster:
|
||||
|
||||
```bash
|
||||
CONFIG=node1.conf ./godis-darwin &
|
||||
@@ -151,7 +152,7 @@ Supported Commands:
|
||||
If you want to read my code in this repository, here is a simple guidance.
|
||||
|
||||
- cmd: only the entry point
|
||||
- config: config parser
|
||||
- config: config parser
|
||||
- interface: some interface definitions
|
||||
- lib: some utils, such as logger, sync utils and wildcard
|
||||
|
||||
@@ -167,7 +168,7 @@ I suggest focusing on the following directories:
|
||||
- sortedset: a sorted set implements based on skiplist
|
||||
- db: the implements of the redis db
|
||||
- db.go: the basement of database
|
||||
- router.go: it find handler for commands
|
||||
- router.go: it find handler for commands
|
||||
- keys.go: handlers for keys commands
|
||||
- string.go: handlers for string commands
|
||||
- list.go: handlers for list commands
|
||||
@@ -176,7 +177,7 @@ I suggest focusing on the following directories:
|
||||
- sortedset.go: handlers for sorted set commands
|
||||
- pubsub.go: implements of publish / subscribe
|
||||
- aof.go: implements of AOF persistence and rewrite
|
||||
|
||||
|
||||
# License
|
||||
|
||||
This project is licensed under the [GPL license](https://github.com/HDT3213/godis/blob/master/LICENSE).
|
34
README_CN.md
34
README_CN.md
@@ -2,8 +2,8 @@ Godis 是一个用 Go 语言实现的 Redis 服务器。本项目旨在为尝试
|
||||
|
||||
**请注意:不要在生产环境使用使用此项目**
|
||||
|
||||
Godis 实现了 Redis 的大多数功能,包括5种数据结构、TTL、发布订阅以及 AOF 持久化。可以在[我的博客](https://www.cnblogs.com/Finley/category/1598973.html)了解更多关于 Godis 的信息。
|
||||
|
||||
Godis 实现了 Redis 的大多数功能,包括5种数据结构、TTL、发布订阅以及 AOF 持久化。可以在[我的博客](https://www.cnblogs.com/Finley/category/1598973.html)了解更多关于
|
||||
Godis 的信息。
|
||||
|
||||
# 运行 Godis
|
||||
|
||||
@@ -149,19 +149,19 @@ redis-cli -p 6399
|
||||
- tcp: tcp 服务器实现
|
||||
- redis: redis 协议解析器
|
||||
- datastruct: redis 的各类数据结构实现
|
||||
- dict: hash 表
|
||||
- list: 链表
|
||||
- lock: 用于锁定 key 的锁组件
|
||||
- set: 基于hash表的集合
|
||||
- sortedset: 基于跳表实现的有序集合
|
||||
- dict: hash 表
|
||||
- list: 链表
|
||||
- lock: 用于锁定 key 的锁组件
|
||||
- set: 基于hash表的集合
|
||||
- sortedset: 基于跳表实现的有序集合
|
||||
- db: redis 存储引擎实现
|
||||
- db.go: 引擎的基础功能
|
||||
- router.go: 将命令路由给响应的处理函数
|
||||
- keys.go: del、ttl、expire 等通用命令实现
|
||||
- string.go: get、set 等字符串命令实现
|
||||
- list.go: lpush、lindex 等列表命令实现
|
||||
- hash.go: hget、hset 等哈希表命令实现
|
||||
- set.go: sadd 等集合命令实现
|
||||
- sortedset.go: zadd 等有序集合命令实现
|
||||
- pubsub.go: 发布订阅命令实现
|
||||
- aof.go: aof持久化实现
|
||||
- db.go: 引擎的基础功能
|
||||
- router.go: 将命令路由给响应的处理函数
|
||||
- keys.go: del、ttl、expire 等通用命令实现
|
||||
- string.go: get、set 等字符串命令实现
|
||||
- list.go: lpush、lindex 等列表命令实现
|
||||
- hash.go: hget、hset 等哈希表命令实现
|
||||
- set.go: sadd 等集合命令实现
|
||||
- sortedset.go: zadd 等有序集合命令实现
|
||||
- pubsub.go: 发布订阅命令实现
|
||||
- aof.go: aof持久化实现
|
@@ -1,45 +1,45 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/HDT3213/godis/src/redis/client"
|
||||
"github.com/jolestar/go-commons-pool/v2"
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/HDT3213/godis/src/redis/client"
|
||||
"github.com/jolestar/go-commons-pool/v2"
|
||||
)
|
||||
|
||||
type ConnectionFactory struct {
|
||||
Peer string
|
||||
Peer string
|
||||
}
|
||||
|
||||
func (f *ConnectionFactory) MakeObject(ctx context.Context) (*pool.PooledObject, error) {
|
||||
c, err := client.MakeClient(f.Peer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Start()
|
||||
return pool.NewPooledObject(c), nil
|
||||
c, err := client.MakeClient(f.Peer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.Start()
|
||||
return pool.NewPooledObject(c), nil
|
||||
}
|
||||
|
||||
func (f *ConnectionFactory) DestroyObject(ctx context.Context, object *pool.PooledObject) error {
|
||||
c, ok := object.Object.(*client.Client)
|
||||
if !ok {
|
||||
return errors.New("type mismatch")
|
||||
}
|
||||
c.Close()
|
||||
return nil
|
||||
c, ok := object.Object.(*client.Client)
|
||||
if !ok {
|
||||
return errors.New("type mismatch")
|
||||
}
|
||||
c.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *ConnectionFactory) ValidateObject(ctx context.Context, object *pool.PooledObject) bool {
|
||||
// do validate
|
||||
return true
|
||||
// do validate
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *ConnectionFactory) ActivateObject(ctx context.Context, object *pool.PooledObject) error {
|
||||
// do activate
|
||||
return nil
|
||||
// do activate
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *ConnectionFactory) PassivateObject(ctx context.Context, object *pool.PooledObject) error {
|
||||
// do passivate
|
||||
return nil
|
||||
// do passivate
|
||||
return nil
|
||||
}
|
||||
|
@@ -1,99 +1,99 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"strconv"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'del' command")
|
||||
}
|
||||
keys := make([]string, len(args)-1)
|
||||
for i := 1; i < len(args); i++ {
|
||||
keys[i-1] = string(args[i])
|
||||
}
|
||||
groupMap := cluster.groupBy(keys)
|
||||
if len(groupMap) == 1 { // do fast
|
||||
for peer, group := range groupMap { // only one group
|
||||
return cluster.Relay(peer, c, makeArgs("DEL", group...))
|
||||
}
|
||||
}
|
||||
// prepare
|
||||
var errReply redis.Reply
|
||||
txId := cluster.idGenerator.NextId()
|
||||
txIdStr := strconv.FormatInt(txId, 10)
|
||||
rollback := false
|
||||
for peer, group := range groupMap {
|
||||
args := []string{txIdStr}
|
||||
args = append(args, group...)
|
||||
var resp redis.Reply
|
||||
if peer == cluster.self {
|
||||
resp = PrepareDel(cluster, c, makeArgs("PrepareDel", args...))
|
||||
} else {
|
||||
resp = cluster.Relay(peer, c, makeArgs("PrepareDel", args...))
|
||||
}
|
||||
if reply.IsErrorReply(resp) {
|
||||
errReply = resp
|
||||
rollback = true
|
||||
break
|
||||
}
|
||||
}
|
||||
var respList []redis.Reply
|
||||
if rollback {
|
||||
// rollback
|
||||
RequestRollback(cluster, c, txId, groupMap)
|
||||
} else {
|
||||
// commit
|
||||
respList, errReply = RequestCommit(cluster, c, txId, groupMap)
|
||||
if errReply != nil {
|
||||
rollback = true
|
||||
}
|
||||
}
|
||||
if !rollback {
|
||||
var deleted int64 = 0
|
||||
for _, resp := range respList {
|
||||
intResp := resp.(*reply.IntReply)
|
||||
deleted += intResp.Code
|
||||
}
|
||||
return reply.MakeIntReply(int64(deleted))
|
||||
}
|
||||
return errReply
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'del' command")
|
||||
}
|
||||
keys := make([]string, len(args)-1)
|
||||
for i := 1; i < len(args); i++ {
|
||||
keys[i-1] = string(args[i])
|
||||
}
|
||||
groupMap := cluster.groupBy(keys)
|
||||
if len(groupMap) == 1 { // do fast
|
||||
for peer, group := range groupMap { // only one group
|
||||
return cluster.Relay(peer, c, makeArgs("DEL", group...))
|
||||
}
|
||||
}
|
||||
// prepare
|
||||
var errReply redis.Reply
|
||||
txId := cluster.idGenerator.NextId()
|
||||
txIdStr := strconv.FormatInt(txId, 10)
|
||||
rollback := false
|
||||
for peer, group := range groupMap {
|
||||
args := []string{txIdStr}
|
||||
args = append(args, group...)
|
||||
var resp redis.Reply
|
||||
if peer == cluster.self {
|
||||
resp = PrepareDel(cluster, c, makeArgs("PrepareDel", args...))
|
||||
} else {
|
||||
resp = cluster.Relay(peer, c, makeArgs("PrepareDel", args...))
|
||||
}
|
||||
if reply.IsErrorReply(resp) {
|
||||
errReply = resp
|
||||
rollback = true
|
||||
break
|
||||
}
|
||||
}
|
||||
var respList []redis.Reply
|
||||
if rollback {
|
||||
// rollback
|
||||
RequestRollback(cluster, c, txId, groupMap)
|
||||
} else {
|
||||
// commit
|
||||
respList, errReply = RequestCommit(cluster, c, txId, groupMap)
|
||||
if errReply != nil {
|
||||
rollback = true
|
||||
}
|
||||
}
|
||||
if !rollback {
|
||||
var deleted int64 = 0
|
||||
for _, resp := range respList {
|
||||
intResp := resp.(*reply.IntReply)
|
||||
deleted += intResp.Code
|
||||
}
|
||||
return reply.MakeIntReply(int64(deleted))
|
||||
}
|
||||
return errReply
|
||||
}
|
||||
|
||||
// args: PrepareDel id keys...
|
||||
func PrepareDel(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||
if len(args) < 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'preparedel' command")
|
||||
}
|
||||
txId := string(args[1])
|
||||
keys := make([]string, 0, len(args)-2)
|
||||
for i := 2; i < len(args); i++ {
|
||||
arg := args[i]
|
||||
keys = append(keys, string(arg))
|
||||
}
|
||||
txArgs := makeArgs("DEL", keys...) // actual args for cluster.db
|
||||
tx := NewTransaction(cluster, c, txId, txArgs, keys)
|
||||
cluster.transactions.Put(txId, tx)
|
||||
err := tx.prepare()
|
||||
if err != nil {
|
||||
return reply.MakeErrReply(err.Error())
|
||||
}
|
||||
return &reply.OkReply{}
|
||||
if len(args) < 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'preparedel' command")
|
||||
}
|
||||
txId := string(args[1])
|
||||
keys := make([]string, 0, len(args)-2)
|
||||
for i := 2; i < len(args); i++ {
|
||||
arg := args[i]
|
||||
keys = append(keys, string(arg))
|
||||
}
|
||||
txArgs := makeArgs("DEL", keys...) // actual args for cluster.db
|
||||
tx := NewTransaction(cluster, c, txId, txArgs, keys)
|
||||
cluster.transactions.Put(txId, tx)
|
||||
err := tx.prepare()
|
||||
if err != nil {
|
||||
return reply.MakeErrReply(err.Error())
|
||||
}
|
||||
return &reply.OkReply{}
|
||||
}
|
||||
|
||||
// invoker should provide lock
|
||||
func CommitDel(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply {
|
||||
keys := make([]string, len(tx.args))
|
||||
for i, v := range tx.args {
|
||||
keys[i] = string(v)
|
||||
}
|
||||
keys = keys[1:]
|
||||
keys := make([]string, len(tx.args))
|
||||
for i, v := range tx.args {
|
||||
keys[i] = string(v)
|
||||
}
|
||||
keys = keys[1:]
|
||||
|
||||
deleted := cluster.db.Removes(keys...)
|
||||
if deleted > 0 {
|
||||
cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args))
|
||||
}
|
||||
return reply.MakeIntReply(int64(deleted))
|
||||
deleted := cluster.db.Removes(keys...)
|
||||
if deleted > 0 {
|
||||
cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args))
|
||||
}
|
||||
return reply.MakeIntReply(int64(deleted))
|
||||
}
|
||||
|
@@ -1,85 +1,85 @@
|
||||
package idgenerator
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
"hash/fnv"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
workerIdBits int64 = 5
|
||||
datacenterIdBits int64 = 5
|
||||
sequenceBits int64 = 12
|
||||
workerIdBits int64 = 5
|
||||
datacenterIdBits int64 = 5
|
||||
sequenceBits int64 = 12
|
||||
|
||||
maxWorkerId int64 = -1 ^ (-1 << uint64(workerIdBits))
|
||||
maxDatacenterId int64 = -1 ^ (-1 << uint64(datacenterIdBits))
|
||||
maxSequence int64 = -1 ^ (-1 << uint64(sequenceBits))
|
||||
maxWorkerId int64 = -1 ^ (-1 << uint64(workerIdBits))
|
||||
maxDatacenterId int64 = -1 ^ (-1 << uint64(datacenterIdBits))
|
||||
maxSequence int64 = -1 ^ (-1 << uint64(sequenceBits))
|
||||
|
||||
timeLeft uint8 = 22
|
||||
dataLeft uint8 = 17
|
||||
workLeft uint8 = 12
|
||||
timeLeft uint8 = 22
|
||||
dataLeft uint8 = 17
|
||||
workLeft uint8 = 12
|
||||
|
||||
twepoch int64 = 1525705533000
|
||||
twepoch int64 = 1525705533000
|
||||
)
|
||||
|
||||
type IdGenerator struct {
|
||||
mu *sync.Mutex
|
||||
lastStamp int64
|
||||
workerId int64
|
||||
dataCenterId int64
|
||||
sequence int64
|
||||
mu *sync.Mutex
|
||||
lastStamp int64
|
||||
workerId int64
|
||||
dataCenterId int64
|
||||
sequence int64
|
||||
}
|
||||
|
||||
func MakeGenerator(cluster string, node string) *IdGenerator {
|
||||
fnv64 := fnv.New64()
|
||||
_, _ = fnv64.Write([]byte(cluster))
|
||||
dataCenterId := int64(fnv64.Sum64())
|
||||
fnv64 := fnv.New64()
|
||||
_, _ = fnv64.Write([]byte(cluster))
|
||||
dataCenterId := int64(fnv64.Sum64())
|
||||
|
||||
fnv64.Reset()
|
||||
_, _ = fnv64.Write([]byte(node))
|
||||
workerId := int64(fnv64.Sum64())
|
||||
fnv64.Reset()
|
||||
_, _ = fnv64.Write([]byte(node))
|
||||
workerId := int64(fnv64.Sum64())
|
||||
|
||||
return &IdGenerator{
|
||||
mu: &sync.Mutex{},
|
||||
lastStamp: -1,
|
||||
dataCenterId: dataCenterId,
|
||||
workerId: workerId,
|
||||
sequence: 1,
|
||||
}
|
||||
return &IdGenerator{
|
||||
mu: &sync.Mutex{},
|
||||
lastStamp: -1,
|
||||
dataCenterId: dataCenterId,
|
||||
workerId: workerId,
|
||||
sequence: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *IdGenerator) getCurrentTime() int64 {
|
||||
return time.Now().UnixNano() / 1e6
|
||||
return time.Now().UnixNano() / 1e6
|
||||
}
|
||||
|
||||
func (w *IdGenerator) NextId() int64 {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
timestamp := w.getCurrentTime()
|
||||
if timestamp < w.lastStamp {
|
||||
log.Fatal("can not generate id")
|
||||
}
|
||||
if w.lastStamp == timestamp {
|
||||
w.sequence = (w.sequence + 1) & maxSequence
|
||||
if w.sequence == 0 {
|
||||
for timestamp <= w.lastStamp {
|
||||
timestamp = w.getCurrentTime()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
w.sequence = 0
|
||||
}
|
||||
w.lastStamp = timestamp
|
||||
timestamp := w.getCurrentTime()
|
||||
if timestamp < w.lastStamp {
|
||||
log.Fatal("can not generate id")
|
||||
}
|
||||
if w.lastStamp == timestamp {
|
||||
w.sequence = (w.sequence + 1) & maxSequence
|
||||
if w.sequence == 0 {
|
||||
for timestamp <= w.lastStamp {
|
||||
timestamp = w.getCurrentTime()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
w.sequence = 0
|
||||
}
|
||||
w.lastStamp = timestamp
|
||||
|
||||
return ((timestamp - twepoch) << timeLeft) | (w.dataCenterId << dataLeft) | (w.workerId << workLeft) | w.sequence
|
||||
return ((timestamp - twepoch) << timeLeft) | (w.dataCenterId << dataLeft) | (w.workerId << workLeft) | w.sequence
|
||||
}
|
||||
|
||||
func (w *IdGenerator) tilNextMillis() int64 {
|
||||
timestamp := w.getCurrentTime()
|
||||
if timestamp <= w.lastStamp {
|
||||
timestamp = w.getCurrentTime()
|
||||
}
|
||||
return timestamp
|
||||
timestamp := w.getCurrentTime()
|
||||
if timestamp <= w.lastStamp {
|
||||
timestamp = w.getCurrentTime()
|
||||
}
|
||||
return timestamp
|
||||
}
|
||||
|
@@ -1,159 +1,159 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/HDT3213/godis/src/db"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"strconv"
|
||||
"fmt"
|
||||
"github.com/HDT3213/godis/src/db"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func MGet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command")
|
||||
}
|
||||
keys := make([]string, len(args)-1)
|
||||
for i := 1; i < len(args); i++ {
|
||||
keys[i-1] = string(args[i])
|
||||
}
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command")
|
||||
}
|
||||
keys := make([]string, len(args)-1)
|
||||
for i := 1; i < len(args); i++ {
|
||||
keys[i-1] = string(args[i])
|
||||
}
|
||||
|
||||
resultMap := make(map[string][]byte)
|
||||
groupMap := cluster.groupBy(keys)
|
||||
for peer, group := range groupMap {
|
||||
resp := cluster.Relay(peer, c, makeArgs("MGET", group...))
|
||||
if reply.IsErrorReply(resp) {
|
||||
errReply := resp.(reply.ErrorReply)
|
||||
return reply.MakeErrReply(fmt.Sprintf("ERR during get %s occurs: %v", group[0], errReply.Error()))
|
||||
}
|
||||
arrReply, _ := resp.(*reply.MultiBulkReply)
|
||||
for i, v := range arrReply.Args {
|
||||
key := group[i]
|
||||
resultMap[key] = v
|
||||
}
|
||||
}
|
||||
result := make([][]byte, len(keys))
|
||||
for i, k := range keys {
|
||||
result[i] = resultMap[k]
|
||||
}
|
||||
return reply.MakeMultiBulkReply(result)
|
||||
resultMap := make(map[string][]byte)
|
||||
groupMap := cluster.groupBy(keys)
|
||||
for peer, group := range groupMap {
|
||||
resp := cluster.Relay(peer, c, makeArgs("MGET", group...))
|
||||
if reply.IsErrorReply(resp) {
|
||||
errReply := resp.(reply.ErrorReply)
|
||||
return reply.MakeErrReply(fmt.Sprintf("ERR during get %s occurs: %v", group[0], errReply.Error()))
|
||||
}
|
||||
arrReply, _ := resp.(*reply.MultiBulkReply)
|
||||
for i, v := range arrReply.Args {
|
||||
key := group[i]
|
||||
resultMap[key] = v
|
||||
}
|
||||
}
|
||||
result := make([][]byte, len(keys))
|
||||
for i, k := range keys {
|
||||
result[i] = resultMap[k]
|
||||
}
|
||||
return reply.MakeMultiBulkReply(result)
|
||||
}
|
||||
|
||||
// args: PrepareMSet id keys...
|
||||
func PrepareMSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||
if len(args) < 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'preparemset' command")
|
||||
}
|
||||
txId := string(args[1])
|
||||
size := (len(args) - 2) / 2
|
||||
keys := make([]string, size)
|
||||
for i := 0; i < size; i++ {
|
||||
keys[i] = string(args[2*i+2])
|
||||
}
|
||||
if len(args) < 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'preparemset' command")
|
||||
}
|
||||
txId := string(args[1])
|
||||
size := (len(args) - 2) / 2
|
||||
keys := make([]string, size)
|
||||
for i := 0; i < size; i++ {
|
||||
keys[i] = string(args[2*i+2])
|
||||
}
|
||||
|
||||
txArgs := [][]byte{
|
||||
[]byte("MSet"),
|
||||
} // actual args for cluster.db
|
||||
txArgs = append(txArgs, args[2:]...)
|
||||
tx := NewTransaction(cluster, c, txId, txArgs, keys)
|
||||
cluster.transactions.Put(txId, tx)
|
||||
err := tx.prepare()
|
||||
if err != nil {
|
||||
return reply.MakeErrReply(err.Error())
|
||||
}
|
||||
return &reply.OkReply{}
|
||||
txArgs := [][]byte{
|
||||
[]byte("MSet"),
|
||||
} // actual args for cluster.db
|
||||
txArgs = append(txArgs, args[2:]...)
|
||||
tx := NewTransaction(cluster, c, txId, txArgs, keys)
|
||||
cluster.transactions.Put(txId, tx)
|
||||
err := tx.prepare()
|
||||
if err != nil {
|
||||
return reply.MakeErrReply(err.Error())
|
||||
}
|
||||
return &reply.OkReply{}
|
||||
}
|
||||
|
||||
// invoker should provide lock
|
||||
func CommitMSet(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply {
|
||||
size := len(tx.args) / 2
|
||||
keys := make([]string, size)
|
||||
values := make([][]byte, size)
|
||||
for i := 0; i < size; i++ {
|
||||
keys[i] = string(tx.args[2*i+1])
|
||||
values[i] = tx.args[2*i+2]
|
||||
}
|
||||
for i, key := range keys {
|
||||
value := values[i]
|
||||
cluster.db.Put(key, &db.DataEntity{Data: value})
|
||||
}
|
||||
cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args))
|
||||
return &reply.OkReply{}
|
||||
size := len(tx.args) / 2
|
||||
keys := make([]string, size)
|
||||
values := make([][]byte, size)
|
||||
for i := 0; i < size; i++ {
|
||||
keys[i] = string(tx.args[2*i+1])
|
||||
values[i] = tx.args[2*i+2]
|
||||
}
|
||||
for i, key := range keys {
|
||||
value := values[i]
|
||||
cluster.db.Put(key, &db.DataEntity{Data: value})
|
||||
}
|
||||
cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args))
|
||||
return &reply.OkReply{}
|
||||
}
|
||||
|
||||
func MSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||
argCount := len(args) - 1
|
||||
if argCount%2 != 0 || argCount < 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
|
||||
}
|
||||
argCount := len(args) - 1
|
||||
if argCount%2 != 0 || argCount < 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
|
||||
}
|
||||
|
||||
size := argCount / 2
|
||||
keys := make([]string, size)
|
||||
valueMap := make(map[string]string)
|
||||
for i := 0; i < size; i++ {
|
||||
keys[i] = string(args[2*i+1])
|
||||
valueMap[keys[i]] = string(args[2*i+2])
|
||||
}
|
||||
size := argCount / 2
|
||||
keys := make([]string, size)
|
||||
valueMap := make(map[string]string)
|
||||
for i := 0; i < size; i++ {
|
||||
keys[i] = string(args[2*i+1])
|
||||
valueMap[keys[i]] = string(args[2*i+2])
|
||||
}
|
||||
|
||||
groupMap := cluster.groupBy(keys)
|
||||
if len(groupMap) == 1 { // do fast
|
||||
for peer := range groupMap {
|
||||
return cluster.Relay(peer, c, args)
|
||||
}
|
||||
}
|
||||
groupMap := cluster.groupBy(keys)
|
||||
if len(groupMap) == 1 { // do fast
|
||||
for peer := range groupMap {
|
||||
return cluster.Relay(peer, c, args)
|
||||
}
|
||||
}
|
||||
|
||||
//prepare
|
||||
var errReply redis.Reply
|
||||
txId := cluster.idGenerator.NextId()
|
||||
txIdStr := strconv.FormatInt(txId, 10)
|
||||
rollback := false
|
||||
for peer, group := range groupMap {
|
||||
peerArgs := []string{txIdStr}
|
||||
for _, k := range group {
|
||||
peerArgs = append(peerArgs, k, valueMap[k])
|
||||
}
|
||||
var resp redis.Reply
|
||||
if peer == cluster.self {
|
||||
resp = PrepareMSet(cluster, c, makeArgs("PrepareMSet", peerArgs...))
|
||||
} else {
|
||||
resp = cluster.Relay(peer, c, makeArgs("PrepareMSet", peerArgs...))
|
||||
}
|
||||
if reply.IsErrorReply(resp) {
|
||||
errReply = resp
|
||||
rollback = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if rollback {
|
||||
// rollback
|
||||
RequestRollback(cluster, c, txId, groupMap)
|
||||
} else {
|
||||
_, errReply = RequestCommit(cluster, c, txId, groupMap)
|
||||
rollback = errReply != nil
|
||||
}
|
||||
if !rollback {
|
||||
return &reply.OkReply{}
|
||||
}
|
||||
return errReply
|
||||
//prepare
|
||||
var errReply redis.Reply
|
||||
txId := cluster.idGenerator.NextId()
|
||||
txIdStr := strconv.FormatInt(txId, 10)
|
||||
rollback := false
|
||||
for peer, group := range groupMap {
|
||||
peerArgs := []string{txIdStr}
|
||||
for _, k := range group {
|
||||
peerArgs = append(peerArgs, k, valueMap[k])
|
||||
}
|
||||
var resp redis.Reply
|
||||
if peer == cluster.self {
|
||||
resp = PrepareMSet(cluster, c, makeArgs("PrepareMSet", peerArgs...))
|
||||
} else {
|
||||
resp = cluster.Relay(peer, c, makeArgs("PrepareMSet", peerArgs...))
|
||||
}
|
||||
if reply.IsErrorReply(resp) {
|
||||
errReply = resp
|
||||
rollback = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if rollback {
|
||||
// rollback
|
||||
RequestRollback(cluster, c, txId, groupMap)
|
||||
} else {
|
||||
_, errReply = RequestCommit(cluster, c, txId, groupMap)
|
||||
rollback = errReply != nil
|
||||
}
|
||||
if !rollback {
|
||||
return &reply.OkReply{}
|
||||
}
|
||||
return errReply
|
||||
|
||||
}
|
||||
|
||||
func MSetNX(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||
argCount := len(args) - 1
|
||||
if argCount%2 != 0 || argCount < 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
|
||||
}
|
||||
var peer string
|
||||
size := argCount / 2
|
||||
for i := 0; i < size; i++ {
|
||||
key := string(args[2*i])
|
||||
currentPeer := cluster.peerPicker.Get(key)
|
||||
if peer == "" {
|
||||
peer = currentPeer
|
||||
} else {
|
||||
if peer != currentPeer {
|
||||
return reply.MakeErrReply("ERR msetnx must within one slot in cluster mode")
|
||||
}
|
||||
}
|
||||
}
|
||||
return cluster.Relay(peer, c, args)
|
||||
argCount := len(args) - 1
|
||||
if argCount%2 != 0 || argCount < 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
|
||||
}
|
||||
var peer string
|
||||
size := argCount / 2
|
||||
for i := 0; i < size; i++ {
|
||||
key := string(args[2*i])
|
||||
currentPeer := cluster.peerPicker.Get(key)
|
||||
if peer == "" {
|
||||
peer = currentPeer
|
||||
} else {
|
||||
if peer != currentPeer {
|
||||
return reply.MakeErrReply("ERR msetnx must within one slot in cluster mode")
|
||||
}
|
||||
}
|
||||
}
|
||||
return cluster.Relay(peer, c, args)
|
||||
}
|
||||
|
@@ -1,40 +1,39 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
)
|
||||
|
||||
// TODO: support multiplex slots
|
||||
func Rename(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rename' command")
|
||||
}
|
||||
src := string(args[1])
|
||||
dest := string(args[2])
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rename' command")
|
||||
}
|
||||
src := string(args[1])
|
||||
dest := string(args[2])
|
||||
|
||||
srcPeer := cluster.peerPicker.Get(src)
|
||||
destPeer := cluster.peerPicker.Get(dest)
|
||||
srcPeer := cluster.peerPicker.Get(src)
|
||||
destPeer := cluster.peerPicker.Get(dest)
|
||||
|
||||
if srcPeer != destPeer {
|
||||
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
|
||||
}
|
||||
return cluster.Relay(srcPeer, c, args)
|
||||
if srcPeer != destPeer {
|
||||
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
|
||||
}
|
||||
return cluster.Relay(srcPeer, c, args)
|
||||
}
|
||||
|
||||
func RenameNx(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'renamenx' command")
|
||||
}
|
||||
src := string(args[1])
|
||||
dest := string(args[2])
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'renamenx' command")
|
||||
}
|
||||
src := string(args[1])
|
||||
dest := string(args[2])
|
||||
|
||||
srcPeer := cluster.peerPicker.Get(src)
|
||||
destPeer := cluster.peerPicker.Get(dest)
|
||||
srcPeer := cluster.peerPicker.Get(src)
|
||||
destPeer := cluster.peerPicker.Get(dest)
|
||||
|
||||
if srcPeer != destPeer {
|
||||
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
|
||||
}
|
||||
return cluster.Relay(srcPeer, c, args)
|
||||
if srcPeer != destPeer {
|
||||
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
|
||||
}
|
||||
return cluster.Relay(srcPeer, c, args)
|
||||
}
|
||||
|
||||
|
@@ -3,106 +3,106 @@ package cluster
|
||||
import "github.com/HDT3213/godis/src/interface/redis"
|
||||
|
||||
func MakeRouter() map[string]CmdFunc {
|
||||
routerMap := make(map[string]CmdFunc)
|
||||
routerMap["ping"] = Ping
|
||||
routerMap := make(map[string]CmdFunc)
|
||||
routerMap["ping"] = Ping
|
||||
|
||||
routerMap["commit"] = Commit
|
||||
routerMap["rollback"] = Rollback
|
||||
routerMap["del"] = Del
|
||||
routerMap["preparedel"] = PrepareDel
|
||||
routerMap["preparemset"] = PrepareMSet
|
||||
routerMap["commit"] = Commit
|
||||
routerMap["rollback"] = Rollback
|
||||
routerMap["del"] = Del
|
||||
routerMap["preparedel"] = PrepareDel
|
||||
routerMap["preparemset"] = PrepareMSet
|
||||
|
||||
routerMap["expire"] = defaultFunc
|
||||
routerMap["expireat"] = defaultFunc
|
||||
routerMap["pexpire"] = defaultFunc
|
||||
routerMap["pexpireat"] = defaultFunc
|
||||
routerMap["ttl"] = defaultFunc
|
||||
routerMap["pttl"] = defaultFunc
|
||||
routerMap["persist"] = defaultFunc
|
||||
routerMap["exists"] = defaultFunc
|
||||
routerMap["type"] = defaultFunc
|
||||
routerMap["rename"] = Rename
|
||||
routerMap["renamenx"] = RenameNx
|
||||
routerMap["expire"] = defaultFunc
|
||||
routerMap["expireat"] = defaultFunc
|
||||
routerMap["pexpire"] = defaultFunc
|
||||
routerMap["pexpireat"] = defaultFunc
|
||||
routerMap["ttl"] = defaultFunc
|
||||
routerMap["pttl"] = defaultFunc
|
||||
routerMap["persist"] = defaultFunc
|
||||
routerMap["exists"] = defaultFunc
|
||||
routerMap["type"] = defaultFunc
|
||||
routerMap["rename"] = Rename
|
||||
routerMap["renamenx"] = RenameNx
|
||||
|
||||
routerMap["set"] = defaultFunc
|
||||
routerMap["setnx"] = defaultFunc
|
||||
routerMap["setex"] = defaultFunc
|
||||
routerMap["psetex"] = defaultFunc
|
||||
routerMap["mset"] = MSet
|
||||
routerMap["mget"] = MGet
|
||||
routerMap["msetnx"] = MSetNX
|
||||
routerMap["get"] = defaultFunc
|
||||
routerMap["getset"] = defaultFunc
|
||||
routerMap["incr"] = defaultFunc
|
||||
routerMap["incrby"] = defaultFunc
|
||||
routerMap["incrbyfloat"] = defaultFunc
|
||||
routerMap["decr"] = defaultFunc
|
||||
routerMap["decrby"] = defaultFunc
|
||||
routerMap["set"] = defaultFunc
|
||||
routerMap["setnx"] = defaultFunc
|
||||
routerMap["setex"] = defaultFunc
|
||||
routerMap["psetex"] = defaultFunc
|
||||
routerMap["mset"] = MSet
|
||||
routerMap["mget"] = MGet
|
||||
routerMap["msetnx"] = MSetNX
|
||||
routerMap["get"] = defaultFunc
|
||||
routerMap["getset"] = defaultFunc
|
||||
routerMap["incr"] = defaultFunc
|
||||
routerMap["incrby"] = defaultFunc
|
||||
routerMap["incrbyfloat"] = defaultFunc
|
||||
routerMap["decr"] = defaultFunc
|
||||
routerMap["decrby"] = defaultFunc
|
||||
|
||||
routerMap["lpush"] = defaultFunc
|
||||
routerMap["lpushx"] = defaultFunc
|
||||
routerMap["rpush"] = defaultFunc
|
||||
routerMap["rpushx"] = defaultFunc
|
||||
routerMap["lpop"] = defaultFunc
|
||||
routerMap["rpop"] = defaultFunc
|
||||
//routerMap["rpoplpush"] = RPopLPush
|
||||
routerMap["lrem"] = defaultFunc
|
||||
routerMap["llen"] = defaultFunc
|
||||
routerMap["lindex"] = defaultFunc
|
||||
routerMap["lset"] = defaultFunc
|
||||
routerMap["lrange"] = defaultFunc
|
||||
routerMap["lpush"] = defaultFunc
|
||||
routerMap["lpushx"] = defaultFunc
|
||||
routerMap["rpush"] = defaultFunc
|
||||
routerMap["rpushx"] = defaultFunc
|
||||
routerMap["lpop"] = defaultFunc
|
||||
routerMap["rpop"] = defaultFunc
|
||||
//routerMap["rpoplpush"] = RPopLPush
|
||||
routerMap["lrem"] = defaultFunc
|
||||
routerMap["llen"] = defaultFunc
|
||||
routerMap["lindex"] = defaultFunc
|
||||
routerMap["lset"] = defaultFunc
|
||||
routerMap["lrange"] = defaultFunc
|
||||
|
||||
routerMap["hset"] = defaultFunc
|
||||
routerMap["hsetnx"] = defaultFunc
|
||||
routerMap["hget"] = defaultFunc
|
||||
routerMap["hexists"] = defaultFunc
|
||||
routerMap["hdel"] = defaultFunc
|
||||
routerMap["hlen"] = defaultFunc
|
||||
routerMap["hmget"] = defaultFunc
|
||||
routerMap["hmset"] = defaultFunc
|
||||
routerMap["hkeys"] = defaultFunc
|
||||
routerMap["hvals"] = defaultFunc
|
||||
routerMap["hgetall"] = defaultFunc
|
||||
routerMap["hincrby"] = defaultFunc
|
||||
routerMap["hincrbyfloat"] = defaultFunc
|
||||
routerMap["hset"] = defaultFunc
|
||||
routerMap["hsetnx"] = defaultFunc
|
||||
routerMap["hget"] = defaultFunc
|
||||
routerMap["hexists"] = defaultFunc
|
||||
routerMap["hdel"] = defaultFunc
|
||||
routerMap["hlen"] = defaultFunc
|
||||
routerMap["hmget"] = defaultFunc
|
||||
routerMap["hmset"] = defaultFunc
|
||||
routerMap["hkeys"] = defaultFunc
|
||||
routerMap["hvals"] = defaultFunc
|
||||
routerMap["hgetall"] = defaultFunc
|
||||
routerMap["hincrby"] = defaultFunc
|
||||
routerMap["hincrbyfloat"] = defaultFunc
|
||||
|
||||
routerMap["sadd"] = defaultFunc
|
||||
routerMap["sismember"] = defaultFunc
|
||||
routerMap["srem"] = defaultFunc
|
||||
routerMap["scard"] = defaultFunc
|
||||
routerMap["smembers"] = defaultFunc
|
||||
routerMap["sinter"] = defaultFunc
|
||||
routerMap["sinterstore"] = defaultFunc
|
||||
routerMap["sunion"] = defaultFunc
|
||||
routerMap["sunionstore"] = defaultFunc
|
||||
routerMap["sdiff"] = defaultFunc
|
||||
routerMap["sdiffstore"] = defaultFunc
|
||||
routerMap["srandmember"] = defaultFunc
|
||||
routerMap["sadd"] = defaultFunc
|
||||
routerMap["sismember"] = defaultFunc
|
||||
routerMap["srem"] = defaultFunc
|
||||
routerMap["scard"] = defaultFunc
|
||||
routerMap["smembers"] = defaultFunc
|
||||
routerMap["sinter"] = defaultFunc
|
||||
routerMap["sinterstore"] = defaultFunc
|
||||
routerMap["sunion"] = defaultFunc
|
||||
routerMap["sunionstore"] = defaultFunc
|
||||
routerMap["sdiff"] = defaultFunc
|
||||
routerMap["sdiffstore"] = defaultFunc
|
||||
routerMap["srandmember"] = defaultFunc
|
||||
|
||||
routerMap["zadd"] = defaultFunc
|
||||
routerMap["zscore"] = defaultFunc
|
||||
routerMap["zincrby"] = defaultFunc
|
||||
routerMap["zrank"] = defaultFunc
|
||||
routerMap["zcount"] = defaultFunc
|
||||
routerMap["zrevrank"] = defaultFunc
|
||||
routerMap["zcard"] = defaultFunc
|
||||
routerMap["zrange"] = defaultFunc
|
||||
routerMap["zrevrange"] = defaultFunc
|
||||
routerMap["zrangebyscore"] = defaultFunc
|
||||
routerMap["zrevrangebyscore"] = defaultFunc
|
||||
routerMap["zrem"] = defaultFunc
|
||||
routerMap["zremrangebyscore"] = defaultFunc
|
||||
routerMap["zremrangebyrank"] = defaultFunc
|
||||
routerMap["zadd"] = defaultFunc
|
||||
routerMap["zscore"] = defaultFunc
|
||||
routerMap["zincrby"] = defaultFunc
|
||||
routerMap["zrank"] = defaultFunc
|
||||
routerMap["zcount"] = defaultFunc
|
||||
routerMap["zrevrank"] = defaultFunc
|
||||
routerMap["zcard"] = defaultFunc
|
||||
routerMap["zrange"] = defaultFunc
|
||||
routerMap["zrevrange"] = defaultFunc
|
||||
routerMap["zrangebyscore"] = defaultFunc
|
||||
routerMap["zrevrangebyscore"] = defaultFunc
|
||||
routerMap["zrem"] = defaultFunc
|
||||
routerMap["zremrangebyscore"] = defaultFunc
|
||||
routerMap["zremrangebyrank"] = defaultFunc
|
||||
|
||||
//routerMap["flushdb"] = FlushDB
|
||||
//routerMap["flushall"] = FlushAll
|
||||
//routerMap["keys"] = Keys
|
||||
//routerMap["flushdb"] = FlushDB
|
||||
//routerMap["flushall"] = FlushAll
|
||||
//routerMap["keys"] = Keys
|
||||
|
||||
return routerMap
|
||||
return routerMap
|
||||
}
|
||||
|
||||
func defaultFunc(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||
key := string(args[1])
|
||||
peer := cluster.peerPicker.Get(key)
|
||||
return cluster.Relay(peer, c, args)
|
||||
}
|
||||
key := string(args[1])
|
||||
peer := cluster.peerPicker.Get(key)
|
||||
return cluster.Relay(peer, c, args)
|
||||
}
|
||||
|
@@ -1,188 +1,188 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/HDT3213/godis/src/db"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/lib/marshal/gob"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/HDT3213/godis/src/db"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/lib/marshal/gob"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Transaction struct {
|
||||
id string // transaction id
|
||||
args [][]byte // cmd args
|
||||
cluster *Cluster
|
||||
conn redis.Connection
|
||||
id string // transaction id
|
||||
args [][]byte // cmd args
|
||||
cluster *Cluster
|
||||
conn redis.Connection
|
||||
|
||||
keys []string // related keys
|
||||
undoLog map[string][]byte // store data for undoLog
|
||||
keys []string // related keys
|
||||
undoLog map[string][]byte // store data for undoLog
|
||||
|
||||
lockUntil time.Time
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
status int8
|
||||
lockUntil time.Time
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
status int8
|
||||
}
|
||||
|
||||
const (
|
||||
maxLockTime = 3 * time.Second
|
||||
maxLockTime = 3 * time.Second
|
||||
|
||||
CreatedStatus = 0
|
||||
PreparedStatus = 1
|
||||
CommitedStatus = 2
|
||||
RollbackedStatus = 3
|
||||
CreatedStatus = 0
|
||||
PreparedStatus = 1
|
||||
CommitedStatus = 2
|
||||
RollbackedStatus = 3
|
||||
)
|
||||
|
||||
func NewTransaction(cluster *Cluster, c redis.Connection, id string, args [][]byte, keys []string) *Transaction {
|
||||
return &Transaction{
|
||||
id: id,
|
||||
args: args,
|
||||
cluster: cluster,
|
||||
conn: c,
|
||||
keys: keys,
|
||||
status: CreatedStatus,
|
||||
}
|
||||
return &Transaction{
|
||||
id: id,
|
||||
args: args,
|
||||
cluster: cluster,
|
||||
conn: c,
|
||||
keys: keys,
|
||||
status: CreatedStatus,
|
||||
}
|
||||
}
|
||||
|
||||
// t should contains Keys field
|
||||
func (tx *Transaction) prepare() error {
|
||||
// lock keys
|
||||
tx.cluster.db.Locks(tx.keys...)
|
||||
// lock keys
|
||||
tx.cluster.db.Locks(tx.keys...)
|
||||
|
||||
// use context to manage
|
||||
//tx.lockUntil = time.Now().Add(maxLockTime)
|
||||
//ctx, cancel := context.WithDeadline(context.Background(), tx.lockUntil)
|
||||
//tx.ctx = ctx
|
||||
//tx.cancel = cancel
|
||||
// use context to manage
|
||||
//tx.lockUntil = time.Now().Add(maxLockTime)
|
||||
//ctx, cancel := context.WithDeadline(context.Background(), tx.lockUntil)
|
||||
//tx.ctx = ctx
|
||||
//tx.cancel = cancel
|
||||
|
||||
// build undoLog
|
||||
tx.undoLog = make(map[string][]byte)
|
||||
for _, key := range tx.keys {
|
||||
entity, ok := tx.cluster.db.Get(key)
|
||||
if ok {
|
||||
blob, err := gob.Marshal(entity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tx.undoLog[key] = blob
|
||||
} else {
|
||||
tx.undoLog[key] = []byte{} // entity was nil, should be removed while rollback
|
||||
}
|
||||
}
|
||||
tx.status = PreparedStatus
|
||||
return nil
|
||||
// build undoLog
|
||||
tx.undoLog = make(map[string][]byte)
|
||||
for _, key := range tx.keys {
|
||||
entity, ok := tx.cluster.db.Get(key)
|
||||
if ok {
|
||||
blob, err := gob.Marshal(entity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tx.undoLog[key] = blob
|
||||
} else {
|
||||
tx.undoLog[key] = []byte{} // entity was nil, should be removed while rollback
|
||||
}
|
||||
}
|
||||
tx.status = PreparedStatus
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *Transaction) rollback() error {
|
||||
for key, blob := range tx.undoLog {
|
||||
if len(blob) > 0 {
|
||||
entity := &db.DataEntity{}
|
||||
err := gob.UnMarshal(blob, entity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tx.cluster.db.Put(key, entity)
|
||||
} else {
|
||||
tx.cluster.db.Remove(key)
|
||||
}
|
||||
}
|
||||
if tx.status != CommitedStatus {
|
||||
tx.cluster.db.UnLocks(tx.keys...)
|
||||
}
|
||||
tx.status = RollbackedStatus
|
||||
return nil
|
||||
for key, blob := range tx.undoLog {
|
||||
if len(blob) > 0 {
|
||||
entity := &db.DataEntity{}
|
||||
err := gob.UnMarshal(blob, entity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tx.cluster.db.Put(key, entity)
|
||||
} else {
|
||||
tx.cluster.db.Remove(key)
|
||||
}
|
||||
}
|
||||
if tx.status != CommitedStatus {
|
||||
tx.cluster.db.UnLocks(tx.keys...)
|
||||
}
|
||||
tx.status = RollbackedStatus
|
||||
return nil
|
||||
}
|
||||
|
||||
// rollback local transaction
|
||||
func Rollback(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||
if len(args) != 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rollback' command")
|
||||
}
|
||||
txId := string(args[1])
|
||||
raw, ok := cluster.transactions.Get(txId)
|
||||
if !ok {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
tx, _ := raw.(*Transaction)
|
||||
err := tx.rollback()
|
||||
if err != nil {
|
||||
return reply.MakeErrReply(err.Error())
|
||||
}
|
||||
return reply.MakeIntReply(1)
|
||||
if len(args) != 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rollback' command")
|
||||
}
|
||||
txId := string(args[1])
|
||||
raw, ok := cluster.transactions.Get(txId)
|
||||
if !ok {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
tx, _ := raw.(*Transaction)
|
||||
err := tx.rollback()
|
||||
if err != nil {
|
||||
return reply.MakeErrReply(err.Error())
|
||||
}
|
||||
return reply.MakeIntReply(1)
|
||||
}
|
||||
|
||||
// commit local transaction as a worker
|
||||
func Commit(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||
if len(args) != 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'commit' command")
|
||||
}
|
||||
txId := string(args[1])
|
||||
raw, ok := cluster.transactions.Get(txId)
|
||||
if !ok {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
tx, _ := raw.(*Transaction)
|
||||
if len(args) != 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'commit' command")
|
||||
}
|
||||
txId := string(args[1])
|
||||
raw, ok := cluster.transactions.Get(txId)
|
||||
if !ok {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
tx, _ := raw.(*Transaction)
|
||||
|
||||
// finish transaction
|
||||
defer func() {
|
||||
cluster.db.UnLocks(tx.keys...)
|
||||
tx.status = CommitedStatus
|
||||
//cluster.transactions.Remove(tx.id) // cannot remove, may rollback after commit
|
||||
}()
|
||||
// finish transaction
|
||||
defer func() {
|
||||
cluster.db.UnLocks(tx.keys...)
|
||||
tx.status = CommitedStatus
|
||||
//cluster.transactions.Remove(tx.id) // cannot remove, may rollback after commit
|
||||
}()
|
||||
|
||||
cmd := strings.ToLower(string(tx.args[0]))
|
||||
var result redis.Reply
|
||||
if cmd == "del" {
|
||||
result = CommitDel(cluster, c, tx)
|
||||
} else if cmd == "mset" {
|
||||
result = CommitMSet(cluster, c, tx)
|
||||
}
|
||||
cmd := strings.ToLower(string(tx.args[0]))
|
||||
var result redis.Reply
|
||||
if cmd == "del" {
|
||||
result = CommitDel(cluster, c, tx)
|
||||
} else if cmd == "mset" {
|
||||
result = CommitMSet(cluster, c, tx)
|
||||
}
|
||||
|
||||
if reply.IsErrorReply(result) {
|
||||
// failed
|
||||
err2 := tx.rollback()
|
||||
return reply.MakeErrReply(fmt.Sprintf("err occurs when rollback: %v, origin err: %s", err2, result))
|
||||
}
|
||||
if reply.IsErrorReply(result) {
|
||||
// failed
|
||||
err2 := tx.rollback()
|
||||
return reply.MakeErrReply(fmt.Sprintf("err occurs when rollback: %v, origin err: %s", err2, result))
|
||||
}
|
||||
|
||||
return result
|
||||
return result
|
||||
}
|
||||
|
||||
// request all node commit transaction as leader
|
||||
func RequestCommit(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) ([]redis.Reply, reply.ErrorReply) {
|
||||
var errReply reply.ErrorReply
|
||||
txIdStr := strconv.FormatInt(txId, 10)
|
||||
respList := make([]redis.Reply, 0, len(peers))
|
||||
for peer := range peers {
|
||||
var resp redis.Reply
|
||||
if peer == cluster.self {
|
||||
resp = Commit(cluster, c, makeArgs("commit", txIdStr))
|
||||
} else {
|
||||
resp = cluster.Relay(peer, c, makeArgs("commit", txIdStr))
|
||||
}
|
||||
if reply.IsErrorReply(resp) {
|
||||
errReply = resp.(reply.ErrorReply)
|
||||
break
|
||||
}
|
||||
respList = append(respList, resp)
|
||||
}
|
||||
if errReply != nil {
|
||||
RequestRollback(cluster, c, txId, peers)
|
||||
return nil, errReply
|
||||
}
|
||||
return respList, nil
|
||||
var errReply reply.ErrorReply
|
||||
txIdStr := strconv.FormatInt(txId, 10)
|
||||
respList := make([]redis.Reply, 0, len(peers))
|
||||
for peer := range peers {
|
||||
var resp redis.Reply
|
||||
if peer == cluster.self {
|
||||
resp = Commit(cluster, c, makeArgs("commit", txIdStr))
|
||||
} else {
|
||||
resp = cluster.Relay(peer, c, makeArgs("commit", txIdStr))
|
||||
}
|
||||
if reply.IsErrorReply(resp) {
|
||||
errReply = resp.(reply.ErrorReply)
|
||||
break
|
||||
}
|
||||
respList = append(respList, resp)
|
||||
}
|
||||
if errReply != nil {
|
||||
RequestRollback(cluster, c, txId, peers)
|
||||
return nil, errReply
|
||||
}
|
||||
return respList, nil
|
||||
}
|
||||
|
||||
// request all node rollback transaction as leader
|
||||
func RequestRollback(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) {
|
||||
txIdStr := strconv.FormatInt(txId, 10)
|
||||
for peer := range peers {
|
||||
if peer == cluster.self {
|
||||
Rollback(cluster, c, makeArgs("rollback", txIdStr))
|
||||
} else {
|
||||
cluster.Relay(peer, c, makeArgs("rollback", txIdStr))
|
||||
}
|
||||
}
|
||||
}
|
||||
txIdStr := strconv.FormatInt(txId, 10)
|
||||
for peer := range peers {
|
||||
if peer == cluster.self {
|
||||
Rollback(cluster, c, makeArgs("rollback", txIdStr))
|
||||
} else {
|
||||
cluster.Relay(peer, c, makeArgs("rollback", txIdStr))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -4,7 +4,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/HDT3213/godis/src/config"
|
||||
"github.com/HDT3213/godis/src/lib/logger"
|
||||
RedisServer "github.com/HDT3213/godis/src/redis/server"
|
||||
RedisServer "github.com/HDT3213/godis/src/redis/server"
|
||||
"github.com/HDT3213/godis/src/tcp"
|
||||
"os"
|
||||
)
|
||||
@@ -23,6 +23,6 @@ func main() {
|
||||
})
|
||||
|
||||
tcp.ListenAndServe(&tcp.Config{
|
||||
Address: fmt.Sprintf("%s:%d", config.Properties.Bind, config.Properties.Port),
|
||||
}, RedisServer.MakeHandler())
|
||||
Address: fmt.Sprintf("%s:%d", config.Properties.Bind, config.Properties.Port),
|
||||
}, RedisServer.MakeHandler())
|
||||
}
|
||||
|
@@ -23,12 +23,12 @@ type PropertyHolder struct {
|
||||
var Properties *PropertyHolder
|
||||
|
||||
func init() {
|
||||
// default config
|
||||
Properties = &PropertyHolder{
|
||||
Bind: "127.0.0.1",
|
||||
Port: 6379,
|
||||
AppendOnly: false,
|
||||
}
|
||||
// default config
|
||||
Properties = &PropertyHolder{
|
||||
Bind: "127.0.0.1",
|
||||
Port: 6379,
|
||||
AppendOnly: false,
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(configFilename string) *PropertyHolder {
|
||||
|
@@ -1,16 +1,16 @@
|
||||
package dict
|
||||
|
||||
type Consumer func(key string, val interface{})bool
|
||||
type Consumer func(key string, val interface{}) bool
|
||||
|
||||
type Dict interface {
|
||||
Get(key string) (val interface{}, exists bool)
|
||||
Len() int
|
||||
Put(key string, val interface{}) (result int)
|
||||
PutIfAbsent(key string, val interface{}) (result int)
|
||||
PutIfExists(key string, val interface{}) (result int)
|
||||
Remove(key string) (result int)
|
||||
ForEach(consumer Consumer)
|
||||
Keys() []string
|
||||
RandomKeys(limit int) []string
|
||||
RandomDistinctKeys(limit int) []string
|
||||
Get(key string) (val interface{}, exists bool)
|
||||
Len() int
|
||||
Put(key string, val interface{}) (result int)
|
||||
PutIfAbsent(key string, val interface{}) (result int)
|
||||
PutIfExists(key string, val interface{}) (result int)
|
||||
Remove(key string) (result int)
|
||||
ForEach(consumer Consumer)
|
||||
Keys() []string
|
||||
RandomKeys(limit int) []string
|
||||
RandomDistinctKeys(limit int) []string
|
||||
}
|
||||
|
@@ -1,244 +1,244 @@
|
||||
package dict
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPut(t *testing.T) {
|
||||
d := MakeConcurrent(0)
|
||||
count := 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(count)
|
||||
for i := 0; i < count; i++ {
|
||||
go func(i int) {
|
||||
// insert
|
||||
key := "k" + strconv.Itoa(i)
|
||||
ret := d.Put(key, i)
|
||||
if ret != 1 { // insert 1
|
||||
t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key)
|
||||
}
|
||||
val, ok := d.Get(key)
|
||||
if ok {
|
||||
intVal, _ := val.(int)
|
||||
if intVal != i {
|
||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key)
|
||||
}
|
||||
} else {
|
||||
_, ok := d.Get(key)
|
||||
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
|
||||
}
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
d := MakeConcurrent(0)
|
||||
count := 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(count)
|
||||
for i := 0; i < count; i++ {
|
||||
go func(i int) {
|
||||
// insert
|
||||
key := "k" + strconv.Itoa(i)
|
||||
ret := d.Put(key, i)
|
||||
if ret != 1 { // insert 1
|
||||
t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key)
|
||||
}
|
||||
val, ok := d.Get(key)
|
||||
if ok {
|
||||
intVal, _ := val.(int)
|
||||
if intVal != i {
|
||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key)
|
||||
}
|
||||
} else {
|
||||
_, ok := d.Get(key)
|
||||
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
|
||||
}
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestPutIfAbsent(t *testing.T) {
|
||||
d := MakeConcurrent(0)
|
||||
count := 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(count)
|
||||
for i := 0; i < count; i++ {
|
||||
go func(i int) {
|
||||
// insert
|
||||
key := "k" + strconv.Itoa(i)
|
||||
ret := d.PutIfAbsent(key, i)
|
||||
if ret != 1 { // insert 1
|
||||
t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key)
|
||||
}
|
||||
val, ok := d.Get(key)
|
||||
if ok {
|
||||
intVal, _ := val.(int)
|
||||
if intVal != i {
|
||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) +
|
||||
", key: " + key)
|
||||
}
|
||||
} else {
|
||||
_, ok := d.Get(key)
|
||||
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
|
||||
}
|
||||
d := MakeConcurrent(0)
|
||||
count := 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(count)
|
||||
for i := 0; i < count; i++ {
|
||||
go func(i int) {
|
||||
// insert
|
||||
key := "k" + strconv.Itoa(i)
|
||||
ret := d.PutIfAbsent(key, i)
|
||||
if ret != 1 { // insert 1
|
||||
t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key)
|
||||
}
|
||||
val, ok := d.Get(key)
|
||||
if ok {
|
||||
intVal, _ := val.(int)
|
||||
if intVal != i {
|
||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) +
|
||||
", key: " + key)
|
||||
}
|
||||
} else {
|
||||
_, ok := d.Get(key)
|
||||
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
|
||||
}
|
||||
|
||||
// update
|
||||
ret = d.PutIfAbsent(key, i * 10)
|
||||
if ret != 0 { // no update
|
||||
t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret))
|
||||
}
|
||||
val, ok = d.Get(key)
|
||||
if ok {
|
||||
intVal, _ := val.(int)
|
||||
if intVal != i {
|
||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key)
|
||||
}
|
||||
} else {
|
||||
t.Error("put test failed: expected true, actual: false, key: " + key)
|
||||
}
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
// update
|
||||
ret = d.PutIfAbsent(key, i*10)
|
||||
if ret != 0 { // no update
|
||||
t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret))
|
||||
}
|
||||
val, ok = d.Get(key)
|
||||
if ok {
|
||||
intVal, _ := val.(int)
|
||||
if intVal != i {
|
||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key)
|
||||
}
|
||||
} else {
|
||||
t.Error("put test failed: expected true, actual: false, key: " + key)
|
||||
}
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestPutIfExists(t *testing.T) {
|
||||
d := MakeConcurrent(0)
|
||||
count := 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(count)
|
||||
for i := 0; i < count; i++ {
|
||||
go func(i int) {
|
||||
// insert
|
||||
key := "k" + strconv.Itoa(i)
|
||||
// insert
|
||||
ret := d.PutIfExists(key, i)
|
||||
if ret != 0 { // insert
|
||||
t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret))
|
||||
}
|
||||
d := MakeConcurrent(0)
|
||||
count := 100
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(count)
|
||||
for i := 0; i < count; i++ {
|
||||
go func(i int) {
|
||||
// insert
|
||||
key := "k" + strconv.Itoa(i)
|
||||
// insert
|
||||
ret := d.PutIfExists(key, i)
|
||||
if ret != 0 { // insert
|
||||
t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret))
|
||||
}
|
||||
|
||||
d.Put(key, i)
|
||||
ret = d.PutIfExists(key, 10 * i)
|
||||
val, ok := d.Get(key)
|
||||
if ok {
|
||||
intVal, _ := val.(int)
|
||||
if intVal != 10 * i {
|
||||
t.Error("put test failed: expected " + strconv.Itoa(10 * i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
} else {
|
||||
_, ok := d.Get(key)
|
||||
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
|
||||
}
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
d.Put(key, i)
|
||||
ret = d.PutIfExists(key, 10*i)
|
||||
val, ok := d.Get(key)
|
||||
if ok {
|
||||
intVal, _ := val.(int)
|
||||
if intVal != 10*i {
|
||||
t.Error("put test failed: expected " + strconv.Itoa(10*i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
} else {
|
||||
_, ok := d.Get(key)
|
||||
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
|
||||
}
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
d := MakeConcurrent(0)
|
||||
d := MakeConcurrent(0)
|
||||
|
||||
// remove head node
|
||||
for i := 0; i < 100; i++ {
|
||||
// insert
|
||||
key := "k" + strconv.Itoa(i)
|
||||
d.Put(key, i)
|
||||
}
|
||||
for i := 0; i < 100; i++ {
|
||||
key := "k" + strconv.Itoa(i)
|
||||
// remove head node
|
||||
for i := 0; i < 100; i++ {
|
||||
// insert
|
||||
key := "k" + strconv.Itoa(i)
|
||||
d.Put(key, i)
|
||||
}
|
||||
for i := 0; i < 100; i++ {
|
||||
key := "k" + strconv.Itoa(i)
|
||||
|
||||
val, ok := d.Get(key)
|
||||
if ok {
|
||||
intVal, _ := val.(int)
|
||||
if intVal != i {
|
||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
} else {
|
||||
t.Error("put test failed: expected true, actual: false")
|
||||
}
|
||||
val, ok := d.Get(key)
|
||||
if ok {
|
||||
intVal, _ := val.(int)
|
||||
if intVal != i {
|
||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
} else {
|
||||
t.Error("put test failed: expected true, actual: false")
|
||||
}
|
||||
|
||||
ret := d.Remove(key)
|
||||
if ret != 1 {
|
||||
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key:" + key)
|
||||
}
|
||||
_, ok = d.Get(key)
|
||||
if ok {
|
||||
t.Error("remove test failed: expected true, actual false")
|
||||
}
|
||||
ret = d.Remove(key)
|
||||
if ret != 0 {
|
||||
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
|
||||
}
|
||||
}
|
||||
ret := d.Remove(key)
|
||||
if ret != 1 {
|
||||
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key:" + key)
|
||||
}
|
||||
_, ok = d.Get(key)
|
||||
if ok {
|
||||
t.Error("remove test failed: expected true, actual false")
|
||||
}
|
||||
ret = d.Remove(key)
|
||||
if ret != 0 {
|
||||
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
|
||||
}
|
||||
}
|
||||
|
||||
// remove tail node
|
||||
d = MakeConcurrent(0)
|
||||
for i := 0; i < 100; i++ {
|
||||
// insert
|
||||
key := "k" + strconv.Itoa(i)
|
||||
d.Put(key, i)
|
||||
}
|
||||
for i := 9; i >= 0; i-- {
|
||||
key := "k" + strconv.Itoa(i)
|
||||
// remove tail node
|
||||
d = MakeConcurrent(0)
|
||||
for i := 0; i < 100; i++ {
|
||||
// insert
|
||||
key := "k" + strconv.Itoa(i)
|
||||
d.Put(key, i)
|
||||
}
|
||||
for i := 9; i >= 0; i-- {
|
||||
key := "k" + strconv.Itoa(i)
|
||||
|
||||
val, ok := d.Get(key)
|
||||
if ok {
|
||||
intVal, _ := val.(int)
|
||||
if intVal != i {
|
||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
} else {
|
||||
t.Error("put test failed: expected true, actual: false")
|
||||
}
|
||||
val, ok := d.Get(key)
|
||||
if ok {
|
||||
intVal, _ := val.(int)
|
||||
if intVal != i {
|
||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
} else {
|
||||
t.Error("put test failed: expected true, actual: false")
|
||||
}
|
||||
|
||||
ret := d.Remove(key)
|
||||
if ret != 1 {
|
||||
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret))
|
||||
}
|
||||
_, ok = d.Get(key)
|
||||
if ok {
|
||||
t.Error("remove test failed: expected true, actual false")
|
||||
}
|
||||
ret = d.Remove(key)
|
||||
if ret != 0 {
|
||||
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
|
||||
}
|
||||
}
|
||||
ret := d.Remove(key)
|
||||
if ret != 1 {
|
||||
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret))
|
||||
}
|
||||
_, ok = d.Get(key)
|
||||
if ok {
|
||||
t.Error("remove test failed: expected true, actual false")
|
||||
}
|
||||
ret = d.Remove(key)
|
||||
if ret != 0 {
|
||||
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
|
||||
}
|
||||
}
|
||||
|
||||
// remove middle node
|
||||
d = MakeConcurrent(0)
|
||||
d.Put("head", 0)
|
||||
for i := 0; i < 10; i++ {
|
||||
// insert
|
||||
key := "k" + strconv.Itoa(i)
|
||||
d.Put(key, i)
|
||||
}
|
||||
d.Put("tail", 0)
|
||||
for i := 9; i >= 0; i-- {
|
||||
key := "k" + strconv.Itoa(i)
|
||||
// remove middle node
|
||||
d = MakeConcurrent(0)
|
||||
d.Put("head", 0)
|
||||
for i := 0; i < 10; i++ {
|
||||
// insert
|
||||
key := "k" + strconv.Itoa(i)
|
||||
d.Put(key, i)
|
||||
}
|
||||
d.Put("tail", 0)
|
||||
for i := 9; i >= 0; i-- {
|
||||
key := "k" + strconv.Itoa(i)
|
||||
|
||||
val, ok := d.Get(key)
|
||||
if ok {
|
||||
intVal, _ := val.(int)
|
||||
if intVal != i {
|
||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
} else {
|
||||
t.Error("put test failed: expected true, actual: false")
|
||||
}
|
||||
val, ok := d.Get(key)
|
||||
if ok {
|
||||
intVal, _ := val.(int)
|
||||
if intVal != i {
|
||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
} else {
|
||||
t.Error("put test failed: expected true, actual: false")
|
||||
}
|
||||
|
||||
ret := d.Remove(key)
|
||||
if ret != 1 {
|
||||
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret))
|
||||
}
|
||||
_, ok = d.Get(key)
|
||||
if ok {
|
||||
t.Error("remove test failed: expected true, actual false")
|
||||
}
|
||||
ret = d.Remove(key)
|
||||
if ret != 0 {
|
||||
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
|
||||
}
|
||||
}
|
||||
ret := d.Remove(key)
|
||||
if ret != 1 {
|
||||
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret))
|
||||
}
|
||||
_, ok = d.Get(key)
|
||||
if ok {
|
||||
t.Error("remove test failed: expected true, actual false")
|
||||
}
|
||||
ret = d.Remove(key)
|
||||
if ret != 0 {
|
||||
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestForEach(t *testing.T) {
|
||||
d := MakeConcurrent(0)
|
||||
size := 100
|
||||
for i := 0; i < size; i++ {
|
||||
// insert
|
||||
key := "k" + strconv.Itoa(i)
|
||||
d.Put(key, i)
|
||||
}
|
||||
i := 0
|
||||
d.ForEach(func(key string, value interface{})bool {
|
||||
intVal, _ := value.(int)
|
||||
expectedKey := "k" + strconv.Itoa(intVal)
|
||||
if key != expectedKey {
|
||||
t.Error("remove test failed: expected " + expectedKey + ", actual: " + key)
|
||||
}
|
||||
i++
|
||||
return true
|
||||
})
|
||||
if i != size {
|
||||
t.Error("remove test failed: expected " + strconv.Itoa(size) + ", actual: " + strconv.Itoa(i))
|
||||
}
|
||||
}
|
||||
d := MakeConcurrent(0)
|
||||
size := 100
|
||||
for i := 0; i < size; i++ {
|
||||
// insert
|
||||
key := "k" + strconv.Itoa(i)
|
||||
d.Put(key, i)
|
||||
}
|
||||
i := 0
|
||||
d.ForEach(func(key string, value interface{}) bool {
|
||||
intVal, _ := value.(int)
|
||||
expectedKey := "k" + strconv.Itoa(intVal)
|
||||
if key != expectedKey {
|
||||
t.Error("remove test failed: expected " + expectedKey + ", actual: " + key)
|
||||
}
|
||||
i++
|
||||
return true
|
||||
})
|
||||
if i != size {
|
||||
t.Error("remove test failed: expected " + strconv.Itoa(size) + ", actual: " + strconv.Itoa(i))
|
||||
}
|
||||
}
|
||||
|
@@ -1,108 +1,108 @@
|
||||
package dict
|
||||
|
||||
type SimpleDict struct {
|
||||
m map[string]interface{}
|
||||
m map[string]interface{}
|
||||
}
|
||||
|
||||
func MakeSimple() *SimpleDict {
|
||||
return &SimpleDict{
|
||||
m: make(map[string]interface{}),
|
||||
}
|
||||
return &SimpleDict{
|
||||
m: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (dict *SimpleDict) Get(key string) (val interface{}, exists bool) {
|
||||
val, ok := dict.m[key]
|
||||
return val, ok
|
||||
val, ok := dict.m[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
func (dict *SimpleDict) Len() int {
|
||||
if dict.m == nil {
|
||||
panic("m is nil")
|
||||
}
|
||||
return len(dict.m)
|
||||
if dict.m == nil {
|
||||
panic("m is nil")
|
||||
}
|
||||
return len(dict.m)
|
||||
}
|
||||
|
||||
func (dict *SimpleDict) Put(key string, val interface{}) (result int) {
|
||||
_, existed := dict.m[key]
|
||||
dict.m[key] = val
|
||||
if existed {
|
||||
return 0
|
||||
} else {
|
||||
return 1
|
||||
}
|
||||
_, existed := dict.m[key]
|
||||
dict.m[key] = val
|
||||
if existed {
|
||||
return 0
|
||||
} else {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func (dict *SimpleDict) PutIfAbsent(key string, val interface{}) (result int) {
|
||||
_, existed := dict.m[key]
|
||||
if existed {
|
||||
return 0
|
||||
} else {
|
||||
dict.m[key] = val
|
||||
return 1
|
||||
}
|
||||
_, existed := dict.m[key]
|
||||
if existed {
|
||||
return 0
|
||||
} else {
|
||||
dict.m[key] = val
|
||||
return 1
|
||||
}
|
||||
}
|
||||
|
||||
func (dict *SimpleDict) PutIfExists(key string, val interface{}) (result int) {
|
||||
_, existed := dict.m[key]
|
||||
if existed {
|
||||
dict.m[key] = val
|
||||
return 1
|
||||
} else {
|
||||
return 0
|
||||
}
|
||||
_, existed := dict.m[key]
|
||||
if existed {
|
||||
dict.m[key] = val
|
||||
return 1
|
||||
} else {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (dict *SimpleDict) Remove(key string) (result int) {
|
||||
_, existed := dict.m[key]
|
||||
delete(dict.m, key)
|
||||
if existed {
|
||||
return 1
|
||||
} else {
|
||||
return 0
|
||||
}
|
||||
_, existed := dict.m[key]
|
||||
delete(dict.m, key)
|
||||
if existed {
|
||||
return 1
|
||||
} else {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (dict *SimpleDict) Keys() []string {
|
||||
result := make([]string, len(dict.m))
|
||||
i := 0
|
||||
for k := range dict.m {
|
||||
result[i] = k
|
||||
}
|
||||
return result
|
||||
result := make([]string, len(dict.m))
|
||||
i := 0
|
||||
for k := range dict.m {
|
||||
result[i] = k
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (dict *SimpleDict) ForEach(consumer Consumer) {
|
||||
for k, v := range dict.m {
|
||||
if !consumer(k, v) {
|
||||
break
|
||||
}
|
||||
}
|
||||
for k, v := range dict.m {
|
||||
if !consumer(k, v) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (dict *SimpleDict) RandomKeys(limit int) []string {
|
||||
result := make([]string, limit)
|
||||
for i := 0; i < limit; i++ {
|
||||
for k := range dict.m {
|
||||
result[i] = k
|
||||
break
|
||||
}
|
||||
}
|
||||
return result
|
||||
result := make([]string, limit)
|
||||
for i := 0; i < limit; i++ {
|
||||
for k := range dict.m {
|
||||
result[i] = k
|
||||
break
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (dict *SimpleDict) RandomDistinctKeys(limit int) []string {
|
||||
size := limit
|
||||
if size > len(dict.m) {
|
||||
size = len(dict.m)
|
||||
}
|
||||
result := make([]string, size)
|
||||
i := 0
|
||||
for k := range dict.m {
|
||||
if i == limit {
|
||||
break
|
||||
}
|
||||
result[i] = k
|
||||
i++
|
||||
}
|
||||
return result
|
||||
size := limit
|
||||
if size > len(dict.m) {
|
||||
size = len(dict.m)
|
||||
}
|
||||
result := make([]string, size)
|
||||
i := 0
|
||||
for k := range dict.m {
|
||||
if i == limit {
|
||||
break
|
||||
}
|
||||
result[i] = k
|
||||
i++
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
@@ -3,324 +3,324 @@ package list
|
||||
import "github.com/HDT3213/godis/src/datastruct/utils"
|
||||
|
||||
type LinkedList struct {
|
||||
first *node
|
||||
last *node
|
||||
size int
|
||||
first *node
|
||||
last *node
|
||||
size int
|
||||
}
|
||||
|
||||
type node struct {
|
||||
val interface{}
|
||||
prev *node
|
||||
next * node
|
||||
val interface{}
|
||||
prev *node
|
||||
next *node
|
||||
}
|
||||
|
||||
func (list *LinkedList)Add(val interface{}) {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
n := &node{
|
||||
val: val,
|
||||
}
|
||||
if list.last == nil {
|
||||
// empty list
|
||||
list.first = n
|
||||
list.last = n
|
||||
} else {
|
||||
n.prev = list.last
|
||||
list.last.next = n
|
||||
list.last = n
|
||||
}
|
||||
list.size++
|
||||
func (list *LinkedList) Add(val interface{}) {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
n := &node{
|
||||
val: val,
|
||||
}
|
||||
if list.last == nil {
|
||||
// empty list
|
||||
list.first = n
|
||||
list.last = n
|
||||
} else {
|
||||
n.prev = list.last
|
||||
list.last.next = n
|
||||
list.last = n
|
||||
}
|
||||
list.size++
|
||||
}
|
||||
|
||||
func (list *LinkedList)find(index int)(n *node) {
|
||||
if index < list.size / 2 {
|
||||
n := list.first
|
||||
for i := 0; i < index; i++ {
|
||||
n = n.next
|
||||
}
|
||||
return n
|
||||
} else {
|
||||
n := list.last
|
||||
for i := list.size - 1; i > index; i-- {
|
||||
n = n.prev
|
||||
}
|
||||
return n
|
||||
}
|
||||
func (list *LinkedList) find(index int) (n *node) {
|
||||
if index < list.size/2 {
|
||||
n := list.first
|
||||
for i := 0; i < index; i++ {
|
||||
n = n.next
|
||||
}
|
||||
return n
|
||||
} else {
|
||||
n := list.last
|
||||
for i := list.size - 1; i > index; i-- {
|
||||
n = n.prev
|
||||
}
|
||||
return n
|
||||
}
|
||||
}
|
||||
|
||||
func (list *LinkedList)Get(index int)(val interface{}) {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
if index < 0 || index >= list.size {
|
||||
panic("index out of bound")
|
||||
}
|
||||
return list.find(index).val
|
||||
func (list *LinkedList) Get(index int) (val interface{}) {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
if index < 0 || index >= list.size {
|
||||
panic("index out of bound")
|
||||
}
|
||||
return list.find(index).val
|
||||
}
|
||||
|
||||
func (list *LinkedList)Set(index int, val interface{}) {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
if index < 0 || index > list.size {
|
||||
panic("index out of bound")
|
||||
}
|
||||
n := list.find(index)
|
||||
n.val = val
|
||||
func (list *LinkedList) Set(index int, val interface{}) {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
if index < 0 || index > list.size {
|
||||
panic("index out of bound")
|
||||
}
|
||||
n := list.find(index)
|
||||
n.val = val
|
||||
}
|
||||
|
||||
func (list *LinkedList)Insert(index int, val interface{}) {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
if index < 0 || index > list.size {
|
||||
panic("index out of bound")
|
||||
}
|
||||
func (list *LinkedList) Insert(index int, val interface{}) {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
if index < 0 || index > list.size {
|
||||
panic("index out of bound")
|
||||
}
|
||||
|
||||
if index == list.size {
|
||||
list.Add(val)
|
||||
return
|
||||
} else {
|
||||
// list is not empty
|
||||
pivot := list.find(index)
|
||||
n := &node{
|
||||
val: val,
|
||||
prev: pivot.prev,
|
||||
next: pivot,
|
||||
}
|
||||
if pivot.prev == nil {
|
||||
list.first = n
|
||||
} else {
|
||||
pivot.prev.next = n
|
||||
}
|
||||
pivot.prev = n
|
||||
list.size++
|
||||
}
|
||||
if index == list.size {
|
||||
list.Add(val)
|
||||
return
|
||||
} else {
|
||||
// list is not empty
|
||||
pivot := list.find(index)
|
||||
n := &node{
|
||||
val: val,
|
||||
prev: pivot.prev,
|
||||
next: pivot,
|
||||
}
|
||||
if pivot.prev == nil {
|
||||
list.first = n
|
||||
} else {
|
||||
pivot.prev.next = n
|
||||
}
|
||||
pivot.prev = n
|
||||
list.size++
|
||||
}
|
||||
}
|
||||
|
||||
func (list *LinkedList)removeNode(n *node) {
|
||||
if n.prev == nil {
|
||||
list.first = n.next
|
||||
} else {
|
||||
n.prev.next = n.next
|
||||
}
|
||||
if n.next == nil {
|
||||
list.last = n.prev
|
||||
} else {
|
||||
n.next.prev = n.prev
|
||||
}
|
||||
func (list *LinkedList) removeNode(n *node) {
|
||||
if n.prev == nil {
|
||||
list.first = n.next
|
||||
} else {
|
||||
n.prev.next = n.next
|
||||
}
|
||||
if n.next == nil {
|
||||
list.last = n.prev
|
||||
} else {
|
||||
n.next.prev = n.prev
|
||||
}
|
||||
|
||||
// for gc
|
||||
n.prev = nil
|
||||
n.next = nil
|
||||
// for gc
|
||||
n.prev = nil
|
||||
n.next = nil
|
||||
|
||||
list.size--
|
||||
list.size--
|
||||
}
|
||||
|
||||
func (list *LinkedList)Remove(index int)(val interface{}) {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
if index < 0 || index >= list.size {
|
||||
panic("index out of bound")
|
||||
}
|
||||
func (list *LinkedList) Remove(index int) (val interface{}) {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
if index < 0 || index >= list.size {
|
||||
panic("index out of bound")
|
||||
}
|
||||
|
||||
n := list.find(index)
|
||||
list.removeNode(n)
|
||||
return n.val
|
||||
n := list.find(index)
|
||||
list.removeNode(n)
|
||||
return n.val
|
||||
}
|
||||
|
||||
func (list *LinkedList)RemoveLast()(val interface{}) {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
if list.last == nil {
|
||||
// empty list
|
||||
return nil
|
||||
}
|
||||
n := list.last
|
||||
list.removeNode(n)
|
||||
return n.val
|
||||
func (list *LinkedList) RemoveLast() (val interface{}) {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
if list.last == nil {
|
||||
// empty list
|
||||
return nil
|
||||
}
|
||||
n := list.last
|
||||
list.removeNode(n)
|
||||
return n.val
|
||||
}
|
||||
|
||||
func (list *LinkedList)RemoveAllByVal(val interface{})int {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
n := list.first
|
||||
removed := 0
|
||||
for n != nil {
|
||||
var toRemoveNode *node
|
||||
if utils.Equals(n.val, val) {
|
||||
toRemoveNode = n
|
||||
}
|
||||
if n.next == nil {
|
||||
if toRemoveNode != nil {
|
||||
removed++
|
||||
list.removeNode(toRemoveNode)
|
||||
}
|
||||
break
|
||||
} else {
|
||||
n = n.next
|
||||
}
|
||||
if toRemoveNode != nil {
|
||||
removed++
|
||||
list.removeNode(toRemoveNode)
|
||||
}
|
||||
}
|
||||
return removed
|
||||
func (list *LinkedList) RemoveAllByVal(val interface{}) int {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
n := list.first
|
||||
removed := 0
|
||||
for n != nil {
|
||||
var toRemoveNode *node
|
||||
if utils.Equals(n.val, val) {
|
||||
toRemoveNode = n
|
||||
}
|
||||
if n.next == nil {
|
||||
if toRemoveNode != nil {
|
||||
removed++
|
||||
list.removeNode(toRemoveNode)
|
||||
}
|
||||
break
|
||||
} else {
|
||||
n = n.next
|
||||
}
|
||||
if toRemoveNode != nil {
|
||||
removed++
|
||||
list.removeNode(toRemoveNode)
|
||||
}
|
||||
}
|
||||
return removed
|
||||
}
|
||||
|
||||
/**
|
||||
* remove at most `count` values of the specified value in this list
|
||||
* scan from left to right
|
||||
*/
|
||||
func (list *LinkedList) RemoveByVal(val interface{}, count int)int {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
n := list.first
|
||||
removed := 0
|
||||
for n != nil {
|
||||
var toRemoveNode *node
|
||||
if utils.Equals(n.val, val) {
|
||||
toRemoveNode = n
|
||||
}
|
||||
if n.next == nil {
|
||||
if toRemoveNode != nil {
|
||||
removed++
|
||||
list.removeNode(toRemoveNode)
|
||||
}
|
||||
break
|
||||
} else {
|
||||
n = n.next
|
||||
}
|
||||
func (list *LinkedList) RemoveByVal(val interface{}, count int) int {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
n := list.first
|
||||
removed := 0
|
||||
for n != nil {
|
||||
var toRemoveNode *node
|
||||
if utils.Equals(n.val, val) {
|
||||
toRemoveNode = n
|
||||
}
|
||||
if n.next == nil {
|
||||
if toRemoveNode != nil {
|
||||
removed++
|
||||
list.removeNode(toRemoveNode)
|
||||
}
|
||||
break
|
||||
} else {
|
||||
n = n.next
|
||||
}
|
||||
|
||||
if toRemoveNode != nil {
|
||||
removed++
|
||||
list.removeNode(toRemoveNode)
|
||||
}
|
||||
if removed == count {
|
||||
break
|
||||
}
|
||||
}
|
||||
return removed
|
||||
if toRemoveNode != nil {
|
||||
removed++
|
||||
list.removeNode(toRemoveNode)
|
||||
}
|
||||
if removed == count {
|
||||
break
|
||||
}
|
||||
}
|
||||
return removed
|
||||
}
|
||||
|
||||
func (list *LinkedList) ReverseRemoveByVal(val interface{}, count int)int {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
n := list.last
|
||||
removed := 0
|
||||
for n != nil {
|
||||
var toRemoveNode *node
|
||||
if utils.Equals(n.val, val) {
|
||||
toRemoveNode = n
|
||||
}
|
||||
if n.prev == nil {
|
||||
if toRemoveNode != nil {
|
||||
removed++
|
||||
list.removeNode(toRemoveNode)
|
||||
}
|
||||
break
|
||||
} else {
|
||||
n = n.prev
|
||||
}
|
||||
func (list *LinkedList) ReverseRemoveByVal(val interface{}, count int) int {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
n := list.last
|
||||
removed := 0
|
||||
for n != nil {
|
||||
var toRemoveNode *node
|
||||
if utils.Equals(n.val, val) {
|
||||
toRemoveNode = n
|
||||
}
|
||||
if n.prev == nil {
|
||||
if toRemoveNode != nil {
|
||||
removed++
|
||||
list.removeNode(toRemoveNode)
|
||||
}
|
||||
break
|
||||
} else {
|
||||
n = n.prev
|
||||
}
|
||||
|
||||
if toRemoveNode != nil {
|
||||
removed++
|
||||
list.removeNode(toRemoveNode)
|
||||
}
|
||||
if removed == count {
|
||||
break
|
||||
}
|
||||
}
|
||||
return removed
|
||||
if toRemoveNode != nil {
|
||||
removed++
|
||||
list.removeNode(toRemoveNode)
|
||||
}
|
||||
if removed == count {
|
||||
break
|
||||
}
|
||||
}
|
||||
return removed
|
||||
}
|
||||
|
||||
func (list *LinkedList)Len()int {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
return list.size
|
||||
func (list *LinkedList) Len() int {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
return list.size
|
||||
}
|
||||
|
||||
func (list *LinkedList)ForEach(consumer func(int, interface{})bool) {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
n := list.first
|
||||
i := 0
|
||||
for n != nil {
|
||||
goNext := consumer(i, n.val)
|
||||
if !goNext || n.next == nil {
|
||||
break
|
||||
} else {
|
||||
i++
|
||||
n = n.next
|
||||
}
|
||||
}
|
||||
func (list *LinkedList) ForEach(consumer func(int, interface{}) bool) {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
n := list.first
|
||||
i := 0
|
||||
for n != nil {
|
||||
goNext := consumer(i, n.val)
|
||||
if !goNext || n.next == nil {
|
||||
break
|
||||
} else {
|
||||
i++
|
||||
n = n.next
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (list *LinkedList)Contains(val interface{})bool {
|
||||
contains := false
|
||||
list.ForEach(func(i int, actual interface{}) bool {
|
||||
if actual == val {
|
||||
contains = true
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
return contains
|
||||
func (list *LinkedList) Contains(val interface{}) bool {
|
||||
contains := false
|
||||
list.ForEach(func(i int, actual interface{}) bool {
|
||||
if actual == val {
|
||||
contains = true
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
return contains
|
||||
}
|
||||
|
||||
func (list *LinkedList)Range(start int, stop int)[]interface{} {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
if start < 0 || start >= list.size {
|
||||
panic("`start` out of range")
|
||||
}
|
||||
if stop < start || stop > list.size {
|
||||
panic("`stop` out of range")
|
||||
}
|
||||
func (list *LinkedList) Range(start int, stop int) []interface{} {
|
||||
if list == nil {
|
||||
panic("list is nil")
|
||||
}
|
||||
if start < 0 || start >= list.size {
|
||||
panic("`start` out of range")
|
||||
}
|
||||
if stop < start || stop > list.size {
|
||||
panic("`stop` out of range")
|
||||
}
|
||||
|
||||
sliceSize := stop - start
|
||||
slice := make([]interface{}, sliceSize)
|
||||
n := list.first
|
||||
i := 0
|
||||
sliceIndex := 0
|
||||
for n != nil {
|
||||
if i >= start && i < stop {
|
||||
slice[sliceIndex] = n.val
|
||||
sliceIndex++
|
||||
} else if i >= stop {
|
||||
break
|
||||
}
|
||||
if n.next == nil {
|
||||
break
|
||||
} else {
|
||||
i++
|
||||
n = n.next
|
||||
}
|
||||
}
|
||||
return slice
|
||||
sliceSize := stop - start
|
||||
slice := make([]interface{}, sliceSize)
|
||||
n := list.first
|
||||
i := 0
|
||||
sliceIndex := 0
|
||||
for n != nil {
|
||||
if i >= start && i < stop {
|
||||
slice[sliceIndex] = n.val
|
||||
sliceIndex++
|
||||
} else if i >= stop {
|
||||
break
|
||||
}
|
||||
if n.next == nil {
|
||||
break
|
||||
} else {
|
||||
i++
|
||||
n = n.next
|
||||
}
|
||||
}
|
||||
return slice
|
||||
}
|
||||
|
||||
func Make(vals ...interface{}) *LinkedList {
|
||||
list := LinkedList{}
|
||||
for _, v := range vals {
|
||||
list.Add(v)
|
||||
}
|
||||
return &list
|
||||
list := LinkedList{}
|
||||
for _, v := range vals {
|
||||
list.Add(v)
|
||||
}
|
||||
return &list
|
||||
}
|
||||
|
||||
func MakeBytesList(vals ...[]byte) *LinkedList {
|
||||
list := LinkedList{}
|
||||
for _, v := range vals {
|
||||
list.Add(v)
|
||||
}
|
||||
return &list
|
||||
}
|
||||
list := LinkedList{}
|
||||
for _, v := range vals {
|
||||
list.Add(v)
|
||||
}
|
||||
return &list
|
||||
}
|
||||
|
@@ -1,215 +1,215 @@
|
||||
package list
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"strconv"
|
||||
"strings"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func ToString(list *LinkedList) string {
|
||||
arr := make([]string, list.size)
|
||||
list.ForEach(func(i int, v interface{}) bool {
|
||||
integer, _ := v.(int)
|
||||
arr[i] = strconv.Itoa(integer)
|
||||
return true
|
||||
})
|
||||
return "[" + strings.Join(arr, ", ") + "]"
|
||||
arr := make([]string, list.size)
|
||||
list.ForEach(func(i int, v interface{}) bool {
|
||||
integer, _ := v.(int)
|
||||
arr[i] = strconv.Itoa(integer)
|
||||
return true
|
||||
})
|
||||
return "[" + strings.Join(arr, ", ") + "]"
|
||||
}
|
||||
|
||||
func TestAdd(t *testing.T) {
|
||||
list := Make()
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Add(i)
|
||||
}
|
||||
list.ForEach(func(i int, v interface{}) bool {
|
||||
intVal, _ := v.(int)
|
||||
if intVal != i {
|
||||
t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
return true
|
||||
})
|
||||
list := Make()
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Add(i)
|
||||
}
|
||||
list.ForEach(func(i int, v interface{}) bool {
|
||||
intVal, _ := v.(int)
|
||||
if intVal != i {
|
||||
t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
list := Make()
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Add(i)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
v := list.Get(i)
|
||||
k, _ := v.(int)
|
||||
if i != k {
|
||||
t.Error("get test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(k))
|
||||
}
|
||||
}
|
||||
list := Make()
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Add(i)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
v := list.Get(i)
|
||||
k, _ := v.(int)
|
||||
if i != k {
|
||||
t.Error("get test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(k))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
list := Make()
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Add(i)
|
||||
}
|
||||
for i := 9; i >= 0; i-- {
|
||||
list.Remove(i)
|
||||
if i != list.Len() {
|
||||
t.Error("remove test fail: expected size " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(list.Len()))
|
||||
}
|
||||
list.ForEach(func(i int, v interface{}) bool {
|
||||
intVal, _ := v.(int)
|
||||
if intVal != i {
|
||||
t.Error("remove test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
list := Make()
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Add(i)
|
||||
}
|
||||
for i := 9; i >= 0; i-- {
|
||||
list.Remove(i)
|
||||
if i != list.Len() {
|
||||
t.Error("remove test fail: expected size " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(list.Len()))
|
||||
}
|
||||
list.ForEach(func(i int, v interface{}) bool {
|
||||
intVal, _ := v.(int)
|
||||
if intVal != i {
|
||||
t.Error("remove test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveVal(t *testing.T) {
|
||||
list := Make()
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Add(i)
|
||||
list.Add(i)
|
||||
}
|
||||
for index := 0; index < list.Len(); index++ {
|
||||
list.RemoveAllByVal(index)
|
||||
list.ForEach(func(i int, v interface{}) bool {
|
||||
intVal, _ := v.(int)
|
||||
if intVal == index {
|
||||
t.Error("remove test fail: found " + strconv.Itoa(index) + " at index: " + strconv.Itoa(i))
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
list := Make()
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Add(i)
|
||||
list.Add(i)
|
||||
}
|
||||
for index := 0; index < list.Len(); index++ {
|
||||
list.RemoveAllByVal(index)
|
||||
list.ForEach(func(i int, v interface{}) bool {
|
||||
intVal, _ := v.(int)
|
||||
if intVal == index {
|
||||
t.Error("remove test fail: found " + strconv.Itoa(index) + " at index: " + strconv.Itoa(i))
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
list = Make()
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Add(i)
|
||||
list.Add(i)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
list.RemoveByVal(i, 1)
|
||||
}
|
||||
list.ForEach(func(i int, v interface{}) bool {
|
||||
intVal, _ := v.(int)
|
||||
if intVal != i {
|
||||
t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
return true
|
||||
})
|
||||
for i := 0; i < 10; i++ {
|
||||
list.RemoveByVal(i, 1)
|
||||
}
|
||||
if list.Len() != 0 {
|
||||
t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len()))
|
||||
}
|
||||
list = Make()
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Add(i)
|
||||
list.Add(i)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
list.RemoveByVal(i, 1)
|
||||
}
|
||||
list.ForEach(func(i int, v interface{}) bool {
|
||||
intVal, _ := v.(int)
|
||||
if intVal != i {
|
||||
t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
return true
|
||||
})
|
||||
for i := 0; i < 10; i++ {
|
||||
list.RemoveByVal(i, 1)
|
||||
}
|
||||
if list.Len() != 0 {
|
||||
t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len()))
|
||||
}
|
||||
|
||||
list = Make()
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Add(i)
|
||||
list.Add(i)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
list.ReverseRemoveByVal(i, 1)
|
||||
}
|
||||
list.ForEach(func(i int, v interface{}) bool {
|
||||
intVal, _ := v.(int)
|
||||
if intVal != i {
|
||||
t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
return true
|
||||
})
|
||||
for i := 0; i < 10; i++ {
|
||||
list.ReverseRemoveByVal(i, 1)
|
||||
}
|
||||
if list.Len() != 0 {
|
||||
t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len()))
|
||||
}
|
||||
list = Make()
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Add(i)
|
||||
list.Add(i)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
list.ReverseRemoveByVal(i, 1)
|
||||
}
|
||||
list.ForEach(func(i int, v interface{}) bool {
|
||||
intVal, _ := v.(int)
|
||||
if intVal != i {
|
||||
t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
return true
|
||||
})
|
||||
for i := 0; i < 10; i++ {
|
||||
list.ReverseRemoveByVal(i, 1)
|
||||
}
|
||||
if list.Len() != 0 {
|
||||
t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len()))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestInsert(t *testing.T) {
|
||||
list := Make()
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Add(i)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Insert(i*2, i)
|
||||
list := Make()
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Add(i)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Insert(i*2, i)
|
||||
|
||||
list.ForEach(func(j int, v interface{}) bool {
|
||||
var expected int
|
||||
if j < (i + 1) * 2 {
|
||||
if j%2 == 0 {
|
||||
expected = j / 2
|
||||
} else {
|
||||
expected = (j - 1) / 2
|
||||
}
|
||||
} else {
|
||||
expected = j - i - 1
|
||||
}
|
||||
actual, _ := list.Get(j).(int)
|
||||
if actual != expected {
|
||||
t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual))
|
||||
}
|
||||
return true
|
||||
})
|
||||
list.ForEach(func(j int, v interface{}) bool {
|
||||
var expected int
|
||||
if j < (i+1)*2 {
|
||||
if j%2 == 0 {
|
||||
expected = j / 2
|
||||
} else {
|
||||
expected = (j - 1) / 2
|
||||
}
|
||||
} else {
|
||||
expected = j - i - 1
|
||||
}
|
||||
actual, _ := list.Get(j).(int)
|
||||
if actual != expected {
|
||||
t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual))
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
for j := 0; j < list.Len(); j++ {
|
||||
var expected int
|
||||
if j < (i + 1) * 2 {
|
||||
if j%2 == 0 {
|
||||
expected = j / 2
|
||||
} else {
|
||||
expected = (j - 1) / 2
|
||||
}
|
||||
} else {
|
||||
expected = j - i - 1
|
||||
}
|
||||
actual, _ := list.Get(j).(int)
|
||||
if actual != expected {
|
||||
t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual))
|
||||
}
|
||||
}
|
||||
for j := 0; j < list.Len(); j++ {
|
||||
var expected int
|
||||
if j < (i+1)*2 {
|
||||
if j%2 == 0 {
|
||||
expected = j / 2
|
||||
} else {
|
||||
expected = (j - 1) / 2
|
||||
}
|
||||
} else {
|
||||
expected = j - i - 1
|
||||
}
|
||||
actual, _ := list.Get(j).(int)
|
||||
if actual != expected {
|
||||
t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual))
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveLast(t *testing.T) {
|
||||
list := Make()
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Add(i)
|
||||
}
|
||||
for i := 9; i >= 0; i-- {
|
||||
val := list.RemoveLast()
|
||||
intVal, _ := val.(int)
|
||||
if intVal != i {
|
||||
t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
}
|
||||
list := Make()
|
||||
for i := 0; i < 10; i++ {
|
||||
list.Add(i)
|
||||
}
|
||||
for i := 9; i >= 0; i-- {
|
||||
val := list.RemoveLast()
|
||||
intVal, _ := val.(int)
|
||||
if intVal != i {
|
||||
t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRange(t *testing.T) {
|
||||
list := Make()
|
||||
size := 10
|
||||
for i := 0; i < size; i++ {
|
||||
list.Add(i)
|
||||
}
|
||||
for start := 0; start < size; start++ {
|
||||
for stop := start; stop < size; stop++ {
|
||||
slice := list.Range(start, stop)
|
||||
if len(slice) != stop - start {
|
||||
t.Error("expected " + strconv.Itoa(stop - start) + ", get: " + strconv.Itoa(len(slice)) +
|
||||
", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]")
|
||||
}
|
||||
sliceIndex := 0
|
||||
for i := start; i < stop; i++ {
|
||||
val := slice[sliceIndex]
|
||||
intVal, _ := val.(int)
|
||||
if intVal != i {
|
||||
t.Error("expected " + strconv.Itoa(i) + ", get: " + strconv.Itoa(intVal) +
|
||||
", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]")
|
||||
}
|
||||
sliceIndex++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
list := Make()
|
||||
size := 10
|
||||
for i := 0; i < size; i++ {
|
||||
list.Add(i)
|
||||
}
|
||||
for start := 0; start < size; start++ {
|
||||
for stop := start; stop < size; stop++ {
|
||||
slice := list.Range(start, stop)
|
||||
if len(slice) != stop-start {
|
||||
t.Error("expected " + strconv.Itoa(stop-start) + ", get: " + strconv.Itoa(len(slice)) +
|
||||
", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]")
|
||||
}
|
||||
sliceIndex := 0
|
||||
for i := start; i < stop; i++ {
|
||||
val := slice[sliceIndex]
|
||||
intVal, _ := val.(int)
|
||||
if intVal != i {
|
||||
t.Error("expected " + strconv.Itoa(i) + ", get: " + strconv.Itoa(intVal) +
|
||||
", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]")
|
||||
}
|
||||
sliceIndex++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1,153 +1,152 @@
|
||||
package lock
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
prime32 = uint32(16777619)
|
||||
prime32 = uint32(16777619)
|
||||
)
|
||||
|
||||
type Locks struct {
|
||||
table []*sync.RWMutex
|
||||
table []*sync.RWMutex
|
||||
}
|
||||
|
||||
func Make(tableSize int) *Locks {
|
||||
table := make([]*sync.RWMutex, tableSize)
|
||||
for i := 0; i < tableSize; i++ {
|
||||
table[i] = &sync.RWMutex{}
|
||||
}
|
||||
return &Locks{
|
||||
table: table,
|
||||
}
|
||||
table := make([]*sync.RWMutex, tableSize)
|
||||
for i := 0; i < tableSize; i++ {
|
||||
table[i] = &sync.RWMutex{}
|
||||
}
|
||||
return &Locks{
|
||||
table: table,
|
||||
}
|
||||
}
|
||||
|
||||
func fnv32(key string) uint32 {
|
||||
hash := uint32(2166136261)
|
||||
for i := 0; i < len(key); i++ {
|
||||
hash *= prime32
|
||||
hash ^= uint32(key[i])
|
||||
}
|
||||
return hash
|
||||
hash := uint32(2166136261)
|
||||
for i := 0; i < len(key); i++ {
|
||||
hash *= prime32
|
||||
hash ^= uint32(key[i])
|
||||
}
|
||||
return hash
|
||||
}
|
||||
|
||||
func (locks *Locks) spread(hashCode uint32) uint32 {
|
||||
if locks == nil {
|
||||
panic("dict is nil")
|
||||
}
|
||||
tableSize := uint32(len(locks.table))
|
||||
return (tableSize - 1) & uint32(hashCode)
|
||||
if locks == nil {
|
||||
panic("dict is nil")
|
||||
}
|
||||
tableSize := uint32(len(locks.table))
|
||||
return (tableSize - 1) & uint32(hashCode)
|
||||
}
|
||||
|
||||
func (locks *Locks)Lock(key string) {
|
||||
index := locks.spread(fnv32(key))
|
||||
mu := locks.table[index]
|
||||
mu.Lock()
|
||||
func (locks *Locks) Lock(key string) {
|
||||
index := locks.spread(fnv32(key))
|
||||
mu := locks.table[index]
|
||||
mu.Lock()
|
||||
}
|
||||
|
||||
func (locks *Locks)RLock(key string) {
|
||||
index := locks.spread(fnv32(key))
|
||||
mu := locks.table[index]
|
||||
mu.RLock()
|
||||
func (locks *Locks) RLock(key string) {
|
||||
index := locks.spread(fnv32(key))
|
||||
mu := locks.table[index]
|
||||
mu.RLock()
|
||||
}
|
||||
|
||||
func (locks *Locks)UnLock(key string) {
|
||||
index := locks.spread(fnv32(key))
|
||||
mu := locks.table[index]
|
||||
mu.Unlock()
|
||||
func (locks *Locks) UnLock(key string) {
|
||||
index := locks.spread(fnv32(key))
|
||||
mu := locks.table[index]
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
func (locks *Locks)RUnLock(key string) {
|
||||
index := locks.spread(fnv32(key))
|
||||
mu := locks.table[index]
|
||||
mu.RUnlock()
|
||||
func (locks *Locks) RUnLock(key string) {
|
||||
index := locks.spread(fnv32(key))
|
||||
mu := locks.table[index]
|
||||
mu.RUnlock()
|
||||
}
|
||||
|
||||
func (locks *Locks) toLockIndices(keys []string, reverse bool) []uint32 {
|
||||
indexMap := make(map[uint32]bool)
|
||||
for _, key := range keys {
|
||||
index := locks.spread(fnv32(key))
|
||||
indexMap[index] = true
|
||||
}
|
||||
indices := make([]uint32, 0, len(indexMap))
|
||||
for index := range indexMap {
|
||||
indices = append(indices, index)
|
||||
}
|
||||
sort.Slice(indices, func(i, j int) bool {
|
||||
if !reverse {
|
||||
return indices[i] < indices[j]
|
||||
} else {
|
||||
return indices[i] > indices[j]
|
||||
}
|
||||
})
|
||||
return indices
|
||||
indexMap := make(map[uint32]bool)
|
||||
for _, key := range keys {
|
||||
index := locks.spread(fnv32(key))
|
||||
indexMap[index] = true
|
||||
}
|
||||
indices := make([]uint32, 0, len(indexMap))
|
||||
for index := range indexMap {
|
||||
indices = append(indices, index)
|
||||
}
|
||||
sort.Slice(indices, func(i, j int) bool {
|
||||
if !reverse {
|
||||
return indices[i] < indices[j]
|
||||
} else {
|
||||
return indices[i] > indices[j]
|
||||
}
|
||||
})
|
||||
return indices
|
||||
}
|
||||
|
||||
func (locks *Locks)Locks(keys ...string) {
|
||||
indices := locks.toLockIndices(keys, false)
|
||||
for _, index := range indices {
|
||||
mu := locks.table[index]
|
||||
mu.Lock()
|
||||
}
|
||||
func (locks *Locks) Locks(keys ...string) {
|
||||
indices := locks.toLockIndices(keys, false)
|
||||
for _, index := range indices {
|
||||
mu := locks.table[index]
|
||||
mu.Lock()
|
||||
}
|
||||
}
|
||||
|
||||
func (locks *Locks)RLocks(keys ...string) {
|
||||
indices := locks.toLockIndices(keys, false)
|
||||
for _, index := range indices {
|
||||
mu := locks.table[index]
|
||||
mu.RLock()
|
||||
}
|
||||
func (locks *Locks) RLocks(keys ...string) {
|
||||
indices := locks.toLockIndices(keys, false)
|
||||
for _, index := range indices {
|
||||
mu := locks.table[index]
|
||||
mu.RLock()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func (locks *Locks)UnLocks(keys ...string) {
|
||||
indices := locks.toLockIndices(keys, true)
|
||||
for _, index := range indices {
|
||||
mu := locks.table[index]
|
||||
mu.Unlock()
|
||||
}
|
||||
func (locks *Locks) UnLocks(keys ...string) {
|
||||
indices := locks.toLockIndices(keys, true)
|
||||
for _, index := range indices {
|
||||
mu := locks.table[index]
|
||||
mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (locks *Locks)RUnLocks(keys ...string) {
|
||||
indices := locks.toLockIndices(keys, true)
|
||||
for _, index := range indices {
|
||||
mu := locks.table[index]
|
||||
mu.RUnlock()
|
||||
}
|
||||
func (locks *Locks) RUnLocks(keys ...string) {
|
||||
indices := locks.toLockIndices(keys, true)
|
||||
for _, index := range indices {
|
||||
mu := locks.table[index]
|
||||
mu.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
func GoID() int {
|
||||
var buf [64]byte
|
||||
n := runtime.Stack(buf[:], false)
|
||||
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
|
||||
id, err := strconv.Atoi(idField)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
|
||||
}
|
||||
return id
|
||||
var buf [64]byte
|
||||
n := runtime.Stack(buf[:], false)
|
||||
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
|
||||
id, err := strconv.Atoi(idField)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func debug(testing.T) {
|
||||
lm := Locks{}
|
||||
size := 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(size)
|
||||
for i := 0; i < size; i++ {
|
||||
go func(i int) {
|
||||
lm.Locks("1", "2")
|
||||
println("go: " + strconv.Itoa(GoID()))
|
||||
time.Sleep(time.Second)
|
||||
println("go: " + strconv.Itoa(GoID()))
|
||||
lm.UnLocks("1", "2")
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
lm := Locks{}
|
||||
size := 10
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(size)
|
||||
for i := 0; i < size; i++ {
|
||||
go func(i int) {
|
||||
lm.Locks("1", "2")
|
||||
println("go: " + strconv.Itoa(GoID()))
|
||||
time.Sleep(time.Second)
|
||||
println("go: " + strconv.Itoa(GoID()))
|
||||
lm.UnLocks("1", "2")
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
@@ -1,32 +1,32 @@
|
||||
package set
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
size := 10
|
||||
set := Make()
|
||||
for i := 0; i < size; i++ {
|
||||
set.Add(strconv.Itoa(i))
|
||||
}
|
||||
for i := 0; i < size; i++ {
|
||||
ok := set.Has(strconv.Itoa(i))
|
||||
if !ok {
|
||||
t.Error("expected true actual false, key: " + strconv.Itoa(i))
|
||||
}
|
||||
}
|
||||
for i := 0; i < size; i++ {
|
||||
ok := set.Remove(strconv.Itoa(i))
|
||||
if ok != 1 {
|
||||
t.Error("expected true actual false, key: " + strconv.Itoa(i))
|
||||
}
|
||||
}
|
||||
for i := 0; i < size; i++ {
|
||||
ok := set.Has(strconv.Itoa(i))
|
||||
if ok {
|
||||
t.Error("expected false actual true, key: " + strconv.Itoa(i))
|
||||
}
|
||||
}
|
||||
}
|
||||
size := 10
|
||||
set := Make()
|
||||
for i := 0; i < size; i++ {
|
||||
set.Add(strconv.Itoa(i))
|
||||
}
|
||||
for i := 0; i < size; i++ {
|
||||
ok := set.Has(strconv.Itoa(i))
|
||||
if !ok {
|
||||
t.Error("expected true actual false, key: " + strconv.Itoa(i))
|
||||
}
|
||||
}
|
||||
for i := 0; i < size; i++ {
|
||||
ok := set.Remove(strconv.Itoa(i))
|
||||
if ok != 1 {
|
||||
t.Error("expected true actual false, key: " + strconv.Itoa(i))
|
||||
}
|
||||
}
|
||||
for i := 0; i < size; i++ {
|
||||
ok := set.Has(strconv.Itoa(i))
|
||||
if ok {
|
||||
t.Error("expected false actual true, key: " + strconv.Itoa(i))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1,8 +1,8 @@
|
||||
package sortedset
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"errors"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
/*
|
||||
@@ -14,78 +14,78 @@ import (
|
||||
*/
|
||||
|
||||
const (
|
||||
negativeInf int8 = -1
|
||||
positiveInf int8 = 1
|
||||
negativeInf int8 = -1
|
||||
positiveInf int8 = 1
|
||||
)
|
||||
|
||||
type ScoreBorder struct {
|
||||
Inf int8
|
||||
Value float64
|
||||
Exclude bool
|
||||
Inf int8
|
||||
Value float64
|
||||
Exclude bool
|
||||
}
|
||||
|
||||
// if max.greater(score) then the score is within the upper border
|
||||
// do not use min.greater()
|
||||
func (border *ScoreBorder)greater(value float64)bool {
|
||||
if border.Inf == negativeInf {
|
||||
return false
|
||||
} else if border.Inf == positiveInf {
|
||||
return true
|
||||
}
|
||||
if border.Exclude {
|
||||
return border.Value > value
|
||||
} else {
|
||||
return border.Value >= value
|
||||
}
|
||||
func (border *ScoreBorder) greater(value float64) bool {
|
||||
if border.Inf == negativeInf {
|
||||
return false
|
||||
} else if border.Inf == positiveInf {
|
||||
return true
|
||||
}
|
||||
if border.Exclude {
|
||||
return border.Value > value
|
||||
} else {
|
||||
return border.Value >= value
|
||||
}
|
||||
}
|
||||
|
||||
func (border *ScoreBorder)less(value float64)bool {
|
||||
if border.Inf == negativeInf {
|
||||
return true
|
||||
} else if border.Inf == positiveInf {
|
||||
return false
|
||||
}
|
||||
if border.Exclude {
|
||||
return border.Value < value
|
||||
} else {
|
||||
return border.Value <= value
|
||||
}
|
||||
func (border *ScoreBorder) less(value float64) bool {
|
||||
if border.Inf == negativeInf {
|
||||
return true
|
||||
} else if border.Inf == positiveInf {
|
||||
return false
|
||||
}
|
||||
if border.Exclude {
|
||||
return border.Value < value
|
||||
} else {
|
||||
return border.Value <= value
|
||||
}
|
||||
}
|
||||
|
||||
var positiveInfBorder = &ScoreBorder {
|
||||
Inf: positiveInf,
|
||||
var positiveInfBorder = &ScoreBorder{
|
||||
Inf: positiveInf,
|
||||
}
|
||||
|
||||
var negativeInfBorder = &ScoreBorder {
|
||||
Inf: negativeInf,
|
||||
var negativeInfBorder = &ScoreBorder{
|
||||
Inf: negativeInf,
|
||||
}
|
||||
|
||||
func ParseScoreBorder(s string)(*ScoreBorder, error) {
|
||||
if s == "inf" || s == "+inf" {
|
||||
return positiveInfBorder, nil
|
||||
}
|
||||
if s == "-inf" {
|
||||
return negativeInfBorder, nil
|
||||
}
|
||||
if s[0] == '(' {
|
||||
value, err := strconv.ParseFloat(s[1:], 64)
|
||||
if err != nil {
|
||||
return nil, errors.New("ERR min or max is not a float")
|
||||
}
|
||||
return &ScoreBorder{
|
||||
Inf: 0,
|
||||
Value: value,
|
||||
Exclude: true,
|
||||
}, nil
|
||||
} else {
|
||||
value, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return nil, errors.New("ERR min or max is not a float")
|
||||
}
|
||||
return &ScoreBorder{
|
||||
Inf: 0,
|
||||
Value: value,
|
||||
Exclude: false,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
func ParseScoreBorder(s string) (*ScoreBorder, error) {
|
||||
if s == "inf" || s == "+inf" {
|
||||
return positiveInfBorder, nil
|
||||
}
|
||||
if s == "-inf" {
|
||||
return negativeInfBorder, nil
|
||||
}
|
||||
if s[0] == '(' {
|
||||
value, err := strconv.ParseFloat(s[1:], 64)
|
||||
if err != nil {
|
||||
return nil, errors.New("ERR min or max is not a float")
|
||||
}
|
||||
return &ScoreBorder{
|
||||
Inf: 0,
|
||||
Value: value,
|
||||
Exclude: true,
|
||||
}, nil
|
||||
} else {
|
||||
value, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return nil, errors.New("ERR min or max is not a float")
|
||||
}
|
||||
return &ScoreBorder{
|
||||
Inf: 0,
|
||||
Value: value,
|
||||
Exclude: false,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
@@ -3,130 +3,129 @@ package sortedset
|
||||
import "math/rand"
|
||||
|
||||
const (
|
||||
maxLevel = 16
|
||||
maxLevel = 16
|
||||
)
|
||||
|
||||
|
||||
type Element struct {
|
||||
Member string
|
||||
Score float64
|
||||
Member string
|
||||
Score float64
|
||||
}
|
||||
|
||||
// level aspect of a Node
|
||||
type Level struct {
|
||||
forward *Node // forward node has greater score
|
||||
span int64
|
||||
forward *Node // forward node has greater score
|
||||
span int64
|
||||
}
|
||||
|
||||
type Node struct {
|
||||
Element
|
||||
backward *Node
|
||||
level []*Level // level[0] is base level
|
||||
Element
|
||||
backward *Node
|
||||
level []*Level // level[0] is base level
|
||||
}
|
||||
|
||||
type skiplist struct {
|
||||
header *Node
|
||||
tail *Node
|
||||
length int64
|
||||
level int16
|
||||
header *Node
|
||||
tail *Node
|
||||
length int64
|
||||
level int16
|
||||
}
|
||||
|
||||
func makeNode(level int16, score float64, member string)*Node {
|
||||
n := &Node{
|
||||
Element: Element{
|
||||
Score: score,
|
||||
Member: member,
|
||||
},
|
||||
level: make([]*Level, level),
|
||||
}
|
||||
for i := range n.level {
|
||||
n.level[i] = new(Level)
|
||||
}
|
||||
return n
|
||||
func makeNode(level int16, score float64, member string) *Node {
|
||||
n := &Node{
|
||||
Element: Element{
|
||||
Score: score,
|
||||
Member: member,
|
||||
},
|
||||
level: make([]*Level, level),
|
||||
}
|
||||
for i := range n.level {
|
||||
n.level[i] = new(Level)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func makeSkiplist()*skiplist {
|
||||
return &skiplist{
|
||||
level: 1,
|
||||
header: makeNode(maxLevel, 0, ""),
|
||||
}
|
||||
func makeSkiplist() *skiplist {
|
||||
return &skiplist{
|
||||
level: 1,
|
||||
header: makeNode(maxLevel, 0, ""),
|
||||
}
|
||||
}
|
||||
|
||||
func randomLevel() int16 {
|
||||
level := int16(1)
|
||||
for float32(rand.Int31()&0xFFFF) < (0.25 * 0xFFFF) {
|
||||
level++
|
||||
}
|
||||
if level < maxLevel {
|
||||
return level
|
||||
}
|
||||
return maxLevel
|
||||
level := int16(1)
|
||||
for float32(rand.Int31()&0xFFFF) < (0.25 * 0xFFFF) {
|
||||
level++
|
||||
}
|
||||
if level < maxLevel {
|
||||
return level
|
||||
}
|
||||
return maxLevel
|
||||
}
|
||||
|
||||
func (skiplist *skiplist)insert(member string, score float64)*Node {
|
||||
update := make([]*Node, maxLevel) // link new node with node in `update`
|
||||
rank := make([]int64, maxLevel)
|
||||
func (skiplist *skiplist) insert(member string, score float64) *Node {
|
||||
update := make([]*Node, maxLevel) // link new node with node in `update`
|
||||
rank := make([]int64, maxLevel)
|
||||
|
||||
// find position to insert
|
||||
node := skiplist.header
|
||||
for i := skiplist.level - 1; i >= 0; i-- {
|
||||
if i == skiplist.level - 1 {
|
||||
rank[i] = 0
|
||||
} else {
|
||||
rank[i] = rank[i + 1] // store rank that is crossed to reach the insert position
|
||||
}
|
||||
if node.level[i] != nil {
|
||||
// traverse the skip list
|
||||
for node.level[i].forward != nil &&
|
||||
(node.level[i].forward.Score < score ||
|
||||
(node.level[i].forward.Score == score && node.level[i].forward.Member < member)) { // same score, different key
|
||||
rank[i] += node.level[i].span
|
||||
node = node.level[i].forward
|
||||
}
|
||||
}
|
||||
update[i] = node
|
||||
}
|
||||
// find position to insert
|
||||
node := skiplist.header
|
||||
for i := skiplist.level - 1; i >= 0; i-- {
|
||||
if i == skiplist.level-1 {
|
||||
rank[i] = 0
|
||||
} else {
|
||||
rank[i] = rank[i+1] // store rank that is crossed to reach the insert position
|
||||
}
|
||||
if node.level[i] != nil {
|
||||
// traverse the skip list
|
||||
for node.level[i].forward != nil &&
|
||||
(node.level[i].forward.Score < score ||
|
||||
(node.level[i].forward.Score == score && node.level[i].forward.Member < member)) { // same score, different key
|
||||
rank[i] += node.level[i].span
|
||||
node = node.level[i].forward
|
||||
}
|
||||
}
|
||||
update[i] = node
|
||||
}
|
||||
|
||||
level := randomLevel()
|
||||
// extend skiplist level
|
||||
if level > skiplist.level {
|
||||
for i := skiplist.level; i < level; i++ {
|
||||
rank[i] = 0
|
||||
update[i] = skiplist.header
|
||||
update[i].level[i].span = skiplist.length
|
||||
}
|
||||
skiplist.level = level
|
||||
}
|
||||
level := randomLevel()
|
||||
// extend skiplist level
|
||||
if level > skiplist.level {
|
||||
for i := skiplist.level; i < level; i++ {
|
||||
rank[i] = 0
|
||||
update[i] = skiplist.header
|
||||
update[i].level[i].span = skiplist.length
|
||||
}
|
||||
skiplist.level = level
|
||||
}
|
||||
|
||||
// make node and link into skiplist
|
||||
node = makeNode(level, score, member)
|
||||
for i := int16(0); i < level; i++ {
|
||||
node.level[i].forward = update[i].level[i].forward
|
||||
update[i].level[i].forward = node
|
||||
// make node and link into skiplist
|
||||
node = makeNode(level, score, member)
|
||||
for i := int16(0); i < level; i++ {
|
||||
node.level[i].forward = update[i].level[i].forward
|
||||
update[i].level[i].forward = node
|
||||
|
||||
// update span covered by update[i] as node is inserted here
|
||||
node.level[i].span = update[i].level[i].span - (rank[0] - rank[i])
|
||||
update[i].level[i].span = (rank[0] - rank[i]) + 1
|
||||
}
|
||||
// update span covered by update[i] as node is inserted here
|
||||
node.level[i].span = update[i].level[i].span - (rank[0] - rank[i])
|
||||
update[i].level[i].span = (rank[0] - rank[i]) + 1
|
||||
}
|
||||
|
||||
// increment span for untouched levels
|
||||
for i := level; i < skiplist.level; i++ {
|
||||
update[i].level[i].span++
|
||||
}
|
||||
// increment span for untouched levels
|
||||
for i := level; i < skiplist.level; i++ {
|
||||
update[i].level[i].span++
|
||||
}
|
||||
|
||||
// set backward node
|
||||
if update[0] == skiplist.header {
|
||||
node.backward = nil
|
||||
} else {
|
||||
node.backward = update[0]
|
||||
}
|
||||
if node.level[0].forward != nil {
|
||||
node.level[0].forward.backward = node
|
||||
} else {
|
||||
skiplist.tail = node
|
||||
}
|
||||
skiplist.length++
|
||||
return node
|
||||
// set backward node
|
||||
if update[0] == skiplist.header {
|
||||
node.backward = nil
|
||||
} else {
|
||||
node.backward = update[0]
|
||||
}
|
||||
if node.level[0].forward != nil {
|
||||
node.level[0].forward.backward = node
|
||||
} else {
|
||||
skiplist.tail = node
|
||||
}
|
||||
skiplist.length++
|
||||
return node
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -134,212 +133,212 @@ func (skiplist *skiplist)insert(member string, score float64)*Node {
|
||||
* param update: backward node (of target)
|
||||
*/
|
||||
func (skiplist *skiplist) removeNode(node *Node, update []*Node) {
|
||||
for i := int16(0); i < skiplist.level; i++ {
|
||||
if update[i].level[i].forward == node {
|
||||
update[i].level[i].span += node.level[i].span - 1
|
||||
update[i].level[i].forward = node.level[i].forward
|
||||
} else {
|
||||
update[i].level[i].span--
|
||||
}
|
||||
}
|
||||
if node.level[0].forward != nil {
|
||||
node.level[0].forward.backward = node.backward
|
||||
} else {
|
||||
skiplist.tail = node.backward
|
||||
}
|
||||
for skiplist.level > 1 && skiplist.header.level[skiplist.level-1].forward == nil {
|
||||
skiplist.level--
|
||||
}
|
||||
skiplist.length--
|
||||
for i := int16(0); i < skiplist.level; i++ {
|
||||
if update[i].level[i].forward == node {
|
||||
update[i].level[i].span += node.level[i].span - 1
|
||||
update[i].level[i].forward = node.level[i].forward
|
||||
} else {
|
||||
update[i].level[i].span--
|
||||
}
|
||||
}
|
||||
if node.level[0].forward != nil {
|
||||
node.level[0].forward.backward = node.backward
|
||||
} else {
|
||||
skiplist.tail = node.backward
|
||||
}
|
||||
for skiplist.level > 1 && skiplist.header.level[skiplist.level-1].forward == nil {
|
||||
skiplist.level--
|
||||
}
|
||||
skiplist.length--
|
||||
}
|
||||
|
||||
/*
|
||||
* return: has found and removed node
|
||||
*/
|
||||
func (skiplist *skiplist) remove(member string, score float64)bool {
|
||||
/*
|
||||
* find backward node (of target) or last node of each level
|
||||
* their forward need to be updated
|
||||
*/
|
||||
update := make([]*Node, maxLevel)
|
||||
node := skiplist.header
|
||||
for i := skiplist.level - 1; i >= 0; i-- {
|
||||
for node.level[i].forward != nil &&
|
||||
(node.level[i].forward.Score < score ||
|
||||
(node.level[i].forward.Score == score &&
|
||||
node.level[i].forward.Member < member)) {
|
||||
node = node.level[i].forward
|
||||
}
|
||||
update[i] = node
|
||||
}
|
||||
node = node.level[0].forward
|
||||
if node != nil && score == node.Score && node.Member == member {
|
||||
skiplist.removeNode(node, update)
|
||||
// free x
|
||||
return true
|
||||
}
|
||||
return false
|
||||
func (skiplist *skiplist) remove(member string, score float64) bool {
|
||||
/*
|
||||
* find backward node (of target) or last node of each level
|
||||
* their forward need to be updated
|
||||
*/
|
||||
update := make([]*Node, maxLevel)
|
||||
node := skiplist.header
|
||||
for i := skiplist.level - 1; i >= 0; i-- {
|
||||
for node.level[i].forward != nil &&
|
||||
(node.level[i].forward.Score < score ||
|
||||
(node.level[i].forward.Score == score &&
|
||||
node.level[i].forward.Member < member)) {
|
||||
node = node.level[i].forward
|
||||
}
|
||||
update[i] = node
|
||||
}
|
||||
node = node.level[0].forward
|
||||
if node != nil && score == node.Score && node.Member == member {
|
||||
skiplist.removeNode(node, update)
|
||||
// free x
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/*
|
||||
* return: 1 based rank, 0 means member not found
|
||||
*/
|
||||
func (skiplist *skiplist) getRank(member string, score float64)int64 {
|
||||
var rank int64 = 0
|
||||
x := skiplist.header
|
||||
for i := skiplist.level - 1; i >= 0; i-- {
|
||||
for x.level[i].forward != nil &&
|
||||
(x.level[i].forward.Score < score ||
|
||||
(x.level[i].forward.Score == score &&
|
||||
x.level[i].forward.Member <= member)) {
|
||||
rank += x.level[i].span
|
||||
x = x.level[i].forward
|
||||
}
|
||||
func (skiplist *skiplist) getRank(member string, score float64) int64 {
|
||||
var rank int64 = 0
|
||||
x := skiplist.header
|
||||
for i := skiplist.level - 1; i >= 0; i-- {
|
||||
for x.level[i].forward != nil &&
|
||||
(x.level[i].forward.Score < score ||
|
||||
(x.level[i].forward.Score == score &&
|
||||
x.level[i].forward.Member <= member)) {
|
||||
rank += x.level[i].span
|
||||
x = x.level[i].forward
|
||||
}
|
||||
|
||||
/* x might be equal to zsl->header, so test if obj is non-NULL */
|
||||
if x.Member == member {
|
||||
return rank
|
||||
}
|
||||
}
|
||||
return 0
|
||||
/* x might be equal to zsl->header, so test if obj is non-NULL */
|
||||
if x.Member == member {
|
||||
return rank
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
/*
|
||||
* 1-based rank
|
||||
*/
|
||||
func (skiplist *skiplist) getByRank(rank int64)*Node {
|
||||
var i int64 = 0
|
||||
n := skiplist.header
|
||||
// scan from top level
|
||||
for level := skiplist.level - 1; level >= 0; level-- {
|
||||
for n.level[level].forward != nil && (i+n.level[level].span) <= rank {
|
||||
i += n.level[level].span
|
||||
n = n.level[level].forward
|
||||
}
|
||||
if i == rank {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return nil
|
||||
func (skiplist *skiplist) getByRank(rank int64) *Node {
|
||||
var i int64 = 0
|
||||
n := skiplist.header
|
||||
// scan from top level
|
||||
for level := skiplist.level - 1; level >= 0; level-- {
|
||||
for n.level[level].forward != nil && (i+n.level[level].span) <= rank {
|
||||
i += n.level[level].span
|
||||
n = n.level[level].forward
|
||||
}
|
||||
if i == rank {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (skiplist *skiplist) hasInRange(min *ScoreBorder, max *ScoreBorder) bool {
|
||||
// min & max = empty
|
||||
if min.Value > max.Value || (min.Value == max.Value && (min.Exclude || max.Exclude)) {
|
||||
return false
|
||||
}
|
||||
// min > tail
|
||||
n := skiplist.tail
|
||||
if n == nil || !min.less(n.Score) {
|
||||
return false
|
||||
}
|
||||
// max < head
|
||||
n = skiplist.header.level[0].forward
|
||||
if n == nil || !max.greater(n.Score) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
// min & max = empty
|
||||
if min.Value > max.Value || (min.Value == max.Value && (min.Exclude || max.Exclude)) {
|
||||
return false
|
||||
}
|
||||
// min > tail
|
||||
n := skiplist.tail
|
||||
if n == nil || !min.less(n.Score) {
|
||||
return false
|
||||
}
|
||||
// max < head
|
||||
n = skiplist.header.level[0].forward
|
||||
if n == nil || !max.greater(n.Score) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (skiplist *skiplist) getFirstInScoreRange(min *ScoreBorder, max *ScoreBorder) *Node {
|
||||
if !skiplist.hasInRange(min, max) {
|
||||
return nil
|
||||
}
|
||||
n := skiplist.header
|
||||
// scan from top level
|
||||
for level := skiplist.level - 1; level >= 0; level-- {
|
||||
// if forward is not in range than move forward
|
||||
for n.level[level].forward != nil && !min.less(n.level[level].forward.Score) {
|
||||
n = n.level[level].forward
|
||||
}
|
||||
}
|
||||
/* This is an inner range, so the next node cannot be NULL. */
|
||||
n = n.level[0].forward
|
||||
if !max.greater(n.Score) {
|
||||
return nil
|
||||
}
|
||||
return n
|
||||
if !skiplist.hasInRange(min, max) {
|
||||
return nil
|
||||
}
|
||||
n := skiplist.header
|
||||
// scan from top level
|
||||
for level := skiplist.level - 1; level >= 0; level-- {
|
||||
// if forward is not in range than move forward
|
||||
for n.level[level].forward != nil && !min.less(n.level[level].forward.Score) {
|
||||
n = n.level[level].forward
|
||||
}
|
||||
}
|
||||
/* This is an inner range, so the next node cannot be NULL. */
|
||||
n = n.level[0].forward
|
||||
if !max.greater(n.Score) {
|
||||
return nil
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (skiplist *skiplist) getLastInScoreRange(min *ScoreBorder, max *ScoreBorder) *Node {
|
||||
if !skiplist.hasInRange(min, max) {
|
||||
return nil
|
||||
}
|
||||
n := skiplist.header
|
||||
// scan from top level
|
||||
for level := skiplist.level - 1; level >= 0; level-- {
|
||||
for n.level[level].forward != nil && max.greater(n.level[level].forward.Score) {
|
||||
n = n.level[level].forward
|
||||
}
|
||||
}
|
||||
if !min.less(n.Score) {
|
||||
return nil
|
||||
}
|
||||
return n
|
||||
if !skiplist.hasInRange(min, max) {
|
||||
return nil
|
||||
}
|
||||
n := skiplist.header
|
||||
// scan from top level
|
||||
for level := skiplist.level - 1; level >= 0; level-- {
|
||||
for n.level[level].forward != nil && max.greater(n.level[level].forward.Score) {
|
||||
n = n.level[level].forward
|
||||
}
|
||||
}
|
||||
if !min.less(n.Score) {
|
||||
return nil
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
/*
|
||||
* return removed elements
|
||||
*/
|
||||
func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder)(removed []*Element) {
|
||||
update := make([]*Node, maxLevel)
|
||||
removed = make([]*Element, 0)
|
||||
// find backward nodes (of target range) or last node of each level
|
||||
node := skiplist.header
|
||||
for i := skiplist.level - 1; i >= 0; i-- {
|
||||
for node.level[i].forward != nil {
|
||||
if min.less(node.level[i].forward.Score) { // already in range
|
||||
break
|
||||
}
|
||||
node = node.level[i].forward
|
||||
}
|
||||
update[i] = node
|
||||
}
|
||||
func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder) (removed []*Element) {
|
||||
update := make([]*Node, maxLevel)
|
||||
removed = make([]*Element, 0)
|
||||
// find backward nodes (of target range) or last node of each level
|
||||
node := skiplist.header
|
||||
for i := skiplist.level - 1; i >= 0; i-- {
|
||||
for node.level[i].forward != nil {
|
||||
if min.less(node.level[i].forward.Score) { // already in range
|
||||
break
|
||||
}
|
||||
node = node.level[i].forward
|
||||
}
|
||||
update[i] = node
|
||||
}
|
||||
|
||||
// node is the first one within range
|
||||
node = node.level[0].forward
|
||||
// node is the first one within range
|
||||
node = node.level[0].forward
|
||||
|
||||
// remove nodes in range
|
||||
for node != nil {
|
||||
if !max.greater(node.Score) { // already out of range
|
||||
break
|
||||
}
|
||||
next := node.level[0].forward
|
||||
removedElement := node.Element
|
||||
removed = append(removed, &removedElement)
|
||||
skiplist.removeNode(node, update)
|
||||
node = next
|
||||
}
|
||||
return removed
|
||||
// remove nodes in range
|
||||
for node != nil {
|
||||
if !max.greater(node.Score) { // already out of range
|
||||
break
|
||||
}
|
||||
next := node.level[0].forward
|
||||
removedElement := node.Element
|
||||
removed = append(removed, &removedElement)
|
||||
skiplist.removeNode(node, update)
|
||||
node = next
|
||||
}
|
||||
return removed
|
||||
}
|
||||
|
||||
// 1-based rank, including start, exclude stop
|
||||
func (skiplist *skiplist) RemoveRangeByRank(start int64, stop int64)(removed []*Element) {
|
||||
var i int64 = 0 // rank of iterator
|
||||
update := make([]*Node, maxLevel)
|
||||
removed = make([]*Element, 0)
|
||||
func (skiplist *skiplist) RemoveRangeByRank(start int64, stop int64) (removed []*Element) {
|
||||
var i int64 = 0 // rank of iterator
|
||||
update := make([]*Node, maxLevel)
|
||||
removed = make([]*Element, 0)
|
||||
|
||||
// scan from top level
|
||||
node := skiplist.header
|
||||
for level := skiplist.level - 1; level >= 0; level-- {
|
||||
for node.level[level].forward != nil && (i+node.level[level].span) < start {
|
||||
i += node.level[level].span
|
||||
node = node.level[level].forward
|
||||
}
|
||||
update[level] = node
|
||||
}
|
||||
// scan from top level
|
||||
node := skiplist.header
|
||||
for level := skiplist.level - 1; level >= 0; level-- {
|
||||
for node.level[level].forward != nil && (i+node.level[level].span) < start {
|
||||
i += node.level[level].span
|
||||
node = node.level[level].forward
|
||||
}
|
||||
update[level] = node
|
||||
}
|
||||
|
||||
i++
|
||||
node = node.level[0].forward // first node in range
|
||||
i++
|
||||
node = node.level[0].forward // first node in range
|
||||
|
||||
// remove nodes in range
|
||||
for node != nil && i < stop {
|
||||
next := node.level[0].forward
|
||||
removedElement := node.Element
|
||||
removed = append(removed, &removedElement)
|
||||
skiplist.removeNode(node, update)
|
||||
node = next
|
||||
i++
|
||||
}
|
||||
return removed
|
||||
// remove nodes in range
|
||||
for node != nil && i < stop {
|
||||
next := node.level[0].forward
|
||||
removedElement := node.Element
|
||||
removed = append(removed, &removedElement)
|
||||
skiplist.removeNode(node, update)
|
||||
node = next
|
||||
i++
|
||||
}
|
||||
return removed
|
||||
}
|
||||
|
@@ -1,227 +1,226 @@
|
||||
package sortedset
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type SortedSet struct {
|
||||
dict map[string]*Element
|
||||
skiplist *skiplist
|
||||
dict map[string]*Element
|
||||
skiplist *skiplist
|
||||
}
|
||||
|
||||
func Make()*SortedSet {
|
||||
return &SortedSet{
|
||||
dict: make(map[string]*Element),
|
||||
skiplist: makeSkiplist(),
|
||||
}
|
||||
func Make() *SortedSet {
|
||||
return &SortedSet{
|
||||
dict: make(map[string]*Element),
|
||||
skiplist: makeSkiplist(),
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* return: has inserted new node
|
||||
*/
|
||||
func (sortedSet *SortedSet)Add(member string, score float64)bool {
|
||||
element, ok := sortedSet.dict[member]
|
||||
sortedSet.dict[member] = &Element{
|
||||
Member: member,
|
||||
Score: score,
|
||||
}
|
||||
if ok {
|
||||
if score != element.Score {
|
||||
sortedSet.skiplist.remove(member, score)
|
||||
sortedSet.skiplist.insert(member, score)
|
||||
}
|
||||
return false
|
||||
} else {
|
||||
sortedSet.skiplist.insert(member, score)
|
||||
return true
|
||||
}
|
||||
func (sortedSet *SortedSet) Add(member string, score float64) bool {
|
||||
element, ok := sortedSet.dict[member]
|
||||
sortedSet.dict[member] = &Element{
|
||||
Member: member,
|
||||
Score: score,
|
||||
}
|
||||
if ok {
|
||||
if score != element.Score {
|
||||
sortedSet.skiplist.remove(member, score)
|
||||
sortedSet.skiplist.insert(member, score)
|
||||
}
|
||||
return false
|
||||
} else {
|
||||
sortedSet.skiplist.insert(member, score)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (sortedSet *SortedSet) Len()int64 {
|
||||
return int64(len(sortedSet.dict))
|
||||
func (sortedSet *SortedSet) Len() int64 {
|
||||
return int64(len(sortedSet.dict))
|
||||
}
|
||||
|
||||
func (sortedSet *SortedSet) Get(member string) (element *Element, ok bool) {
|
||||
element, ok = sortedSet.dict[member]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return element, true
|
||||
element, ok = sortedSet.dict[member]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return element, true
|
||||
}
|
||||
|
||||
func (sortedSet *SortedSet) Remove(member string)bool {
|
||||
v, ok := sortedSet.dict[member]
|
||||
if ok {
|
||||
sortedSet.skiplist.remove(member, v.Score)
|
||||
delete(sortedSet.dict, member)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
func (sortedSet *SortedSet) Remove(member string) bool {
|
||||
v, ok := sortedSet.dict[member]
|
||||
if ok {
|
||||
sortedSet.skiplist.remove(member, v.Score)
|
||||
delete(sortedSet.dict, member)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* get 0-based rank
|
||||
*/
|
||||
func (sortedSet *SortedSet) GetRank(member string, desc bool) (rank int64) {
|
||||
element, ok := sortedSet.dict[member]
|
||||
if !ok {
|
||||
return -1
|
||||
}
|
||||
r := sortedSet.skiplist.getRank(member, element.Score)
|
||||
if desc {
|
||||
r = sortedSet.skiplist.length - r
|
||||
} else {
|
||||
r--
|
||||
}
|
||||
return r
|
||||
element, ok := sortedSet.dict[member]
|
||||
if !ok {
|
||||
return -1
|
||||
}
|
||||
r := sortedSet.skiplist.getRank(member, element.Score)
|
||||
if desc {
|
||||
r = sortedSet.skiplist.length - r
|
||||
} else {
|
||||
r--
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
/**
|
||||
* traverse [start, stop), 0-based rank
|
||||
*/
|
||||
func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer func(element *Element)bool) {
|
||||
size := int64(sortedSet.Len())
|
||||
if start < 0 || start >= size {
|
||||
panic("illegal start " + strconv.FormatInt(start, 10))
|
||||
}
|
||||
if stop < start || stop > size {
|
||||
panic("illegal end " + strconv.FormatInt(stop, 10))
|
||||
}
|
||||
func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer func(element *Element) bool) {
|
||||
size := int64(sortedSet.Len())
|
||||
if start < 0 || start >= size {
|
||||
panic("illegal start " + strconv.FormatInt(start, 10))
|
||||
}
|
||||
if stop < start || stop > size {
|
||||
panic("illegal end " + strconv.FormatInt(stop, 10))
|
||||
}
|
||||
|
||||
// find start node
|
||||
var node *Node
|
||||
if desc {
|
||||
node = sortedSet.skiplist.tail
|
||||
if start > 0 {
|
||||
node = sortedSet.skiplist.getByRank(int64(size - start))
|
||||
}
|
||||
} else {
|
||||
node = sortedSet.skiplist.header.level[0].forward
|
||||
if start > 0 {
|
||||
node = sortedSet.skiplist.getByRank(int64(start + 1))
|
||||
}
|
||||
}
|
||||
// find start node
|
||||
var node *Node
|
||||
if desc {
|
||||
node = sortedSet.skiplist.tail
|
||||
if start > 0 {
|
||||
node = sortedSet.skiplist.getByRank(int64(size - start))
|
||||
}
|
||||
} else {
|
||||
node = sortedSet.skiplist.header.level[0].forward
|
||||
if start > 0 {
|
||||
node = sortedSet.skiplist.getByRank(int64(start + 1))
|
||||
}
|
||||
}
|
||||
|
||||
sliceSize := int(stop - start)
|
||||
for i := 0; i < sliceSize; i++ {
|
||||
if !consumer(&node.Element) {
|
||||
break
|
||||
}
|
||||
if desc {
|
||||
node = node.backward
|
||||
} else {
|
||||
node = node.level[0].forward
|
||||
}
|
||||
}
|
||||
sliceSize := int(stop - start)
|
||||
for i := 0; i < sliceSize; i++ {
|
||||
if !consumer(&node.Element) {
|
||||
break
|
||||
}
|
||||
if desc {
|
||||
node = node.backward
|
||||
} else {
|
||||
node = node.level[0].forward
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* return [start, stop), 0-based rank
|
||||
* assert start in [0, size), stop in [start, size]
|
||||
*/
|
||||
func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool)[]*Element {
|
||||
sliceSize := int(stop - start)
|
||||
slice := make([]*Element, sliceSize)
|
||||
i := 0
|
||||
sortedSet.ForEach(start, stop, desc, func(element *Element)bool {
|
||||
slice[i] = element
|
||||
i++
|
||||
return true
|
||||
})
|
||||
return slice
|
||||
func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool) []*Element {
|
||||
sliceSize := int(stop - start)
|
||||
slice := make([]*Element, sliceSize)
|
||||
i := 0
|
||||
sortedSet.ForEach(start, stop, desc, func(element *Element) bool {
|
||||
slice[i] = element
|
||||
i++
|
||||
return true
|
||||
})
|
||||
return slice
|
||||
}
|
||||
|
||||
func (sortedSet *SortedSet) Count(min *ScoreBorder, max *ScoreBorder)int64 {
|
||||
var i int64 = 0
|
||||
// ascending order
|
||||
sortedSet.ForEach(0, sortedSet.Len(), false, func(element *Element) bool {
|
||||
gtMin := min.less(element.Score) // greater than min
|
||||
if !gtMin {
|
||||
// has not into range, continue foreach
|
||||
return true
|
||||
}
|
||||
ltMax := max.greater(element.Score) // less than max
|
||||
if !ltMax {
|
||||
// break through score border, break foreach
|
||||
return false
|
||||
}
|
||||
// gtMin && ltMax
|
||||
i++
|
||||
return true
|
||||
})
|
||||
return i
|
||||
func (sortedSet *SortedSet) Count(min *ScoreBorder, max *ScoreBorder) int64 {
|
||||
var i int64 = 0
|
||||
// ascending order
|
||||
sortedSet.ForEach(0, sortedSet.Len(), false, func(element *Element) bool {
|
||||
gtMin := min.less(element.Score) // greater than min
|
||||
if !gtMin {
|
||||
// has not into range, continue foreach
|
||||
return true
|
||||
}
|
||||
ltMax := max.greater(element.Score) // less than max
|
||||
if !ltMax {
|
||||
// break through score border, break foreach
|
||||
return false
|
||||
}
|
||||
// gtMin && ltMax
|
||||
i++
|
||||
return true
|
||||
})
|
||||
return i
|
||||
}
|
||||
|
||||
func (sortedSet *SortedSet) ForEachByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool, consumer func(element *Element) bool) {
|
||||
// find start node
|
||||
var node *Node
|
||||
if desc {
|
||||
node = sortedSet.skiplist.getLastInScoreRange(min, max)
|
||||
} else {
|
||||
node = sortedSet.skiplist.getFirstInScoreRange(min, max)
|
||||
}
|
||||
// find start node
|
||||
var node *Node
|
||||
if desc {
|
||||
node = sortedSet.skiplist.getLastInScoreRange(min, max)
|
||||
} else {
|
||||
node = sortedSet.skiplist.getFirstInScoreRange(min, max)
|
||||
}
|
||||
|
||||
for node != nil && offset > 0 {
|
||||
if desc {
|
||||
node = node.backward
|
||||
} else {
|
||||
node = node.level[0].forward
|
||||
}
|
||||
offset--
|
||||
}
|
||||
for node != nil && offset > 0 {
|
||||
if desc {
|
||||
node = node.backward
|
||||
} else {
|
||||
node = node.level[0].forward
|
||||
}
|
||||
offset--
|
||||
}
|
||||
|
||||
// A negative limit returns all elements from the offset
|
||||
for i := 0; (i < int(limit) || limit < 0) && node != nil; i++ {
|
||||
if !consumer(&node.Element) {
|
||||
break
|
||||
}
|
||||
if desc {
|
||||
node = node.backward
|
||||
} else {
|
||||
node = node.level[0].forward
|
||||
}
|
||||
if node == nil {
|
||||
break
|
||||
}
|
||||
gtMin := min.less(node.Element.Score) // greater than min
|
||||
ltMax := max.greater(node.Element.Score)
|
||||
if !gtMin || !ltMax {
|
||||
break // break through score border
|
||||
}
|
||||
}
|
||||
// A negative limit returns all elements from the offset
|
||||
for i := 0; (i < int(limit) || limit < 0) && node != nil; i++ {
|
||||
if !consumer(&node.Element) {
|
||||
break
|
||||
}
|
||||
if desc {
|
||||
node = node.backward
|
||||
} else {
|
||||
node = node.level[0].forward
|
||||
}
|
||||
if node == nil {
|
||||
break
|
||||
}
|
||||
gtMin := min.less(node.Element.Score) // greater than min
|
||||
ltMax := max.greater(node.Element.Score)
|
||||
if !gtMin || !ltMax {
|
||||
break // break through score border
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* param limit: <0 means no limit
|
||||
*/
|
||||
func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool)[]*Element {
|
||||
if limit == 0 || offset < 0{
|
||||
return make([]*Element, 0)
|
||||
}
|
||||
slice := make([]*Element, 0)
|
||||
sortedSet.ForEachByScore(min, max, offset, limit, desc, func(element *Element) bool {
|
||||
slice = append(slice, element)
|
||||
return true
|
||||
})
|
||||
return slice
|
||||
func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool) []*Element {
|
||||
if limit == 0 || offset < 0 {
|
||||
return make([]*Element, 0)
|
||||
}
|
||||
slice := make([]*Element, 0)
|
||||
sortedSet.ForEachByScore(min, max, offset, limit, desc, func(element *Element) bool {
|
||||
slice = append(slice, element)
|
||||
return true
|
||||
})
|
||||
return slice
|
||||
}
|
||||
|
||||
func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder)int64 {
|
||||
removed := sortedSet.skiplist.RemoveRangeByScore(min, max)
|
||||
for _, element := range removed {
|
||||
delete(sortedSet.dict, element.Member)
|
||||
}
|
||||
return int64(len(removed))
|
||||
func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder) int64 {
|
||||
removed := sortedSet.skiplist.RemoveRangeByScore(min, max)
|
||||
for _, element := range removed {
|
||||
delete(sortedSet.dict, element.Member)
|
||||
}
|
||||
return int64(len(removed))
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* 0-based rank, [start, stop)
|
||||
*/
|
||||
func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64)int64 {
|
||||
removed := sortedSet.skiplist.RemoveRangeByRank(start + 1, stop + 1)
|
||||
for _, element := range removed {
|
||||
delete(sortedSet.dict, element.Member)
|
||||
}
|
||||
return int64(len(removed))
|
||||
}
|
||||
func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64) int64 {
|
||||
removed := sortedSet.skiplist.RemoveRangeByRank(start+1, stop+1)
|
||||
for _, element := range removed {
|
||||
delete(sortedSet.dict, element.Member)
|
||||
}
|
||||
return int64(len(removed))
|
||||
}
|
||||
|
@@ -1,28 +1,28 @@
|
||||
package utils
|
||||
|
||||
func Equals(a interface{}, b interface{})bool {
|
||||
sliceA, okA := a.([]byte)
|
||||
sliceB, okB := b.([]byte)
|
||||
if okA && okB {
|
||||
return BytesEquals(sliceA, sliceB)
|
||||
}
|
||||
return a == b
|
||||
func Equals(a interface{}, b interface{}) bool {
|
||||
sliceA, okA := a.([]byte)
|
||||
sliceB, okB := b.([]byte)
|
||||
if okA && okB {
|
||||
return BytesEquals(sliceA, sliceB)
|
||||
}
|
||||
return a == b
|
||||
}
|
||||
|
||||
func BytesEquals(a []byte, b []byte) bool {
|
||||
if (a == nil && b != nil) || (a != nil && b == nil) {
|
||||
return false
|
||||
}
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
size := len(a)
|
||||
for i := 0; i < size; i++ {
|
||||
av := a[i]
|
||||
bv := b[i]
|
||||
if av != bv {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
if (a == nil && b != nil) || (a != nil && b == nil) {
|
||||
return false
|
||||
}
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
size := len(a)
|
||||
for i := 0; i < size; i++ {
|
||||
av := a[i]
|
||||
bv := b[i]
|
||||
if av != bv {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
@@ -2,7 +2,7 @@ package db
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"github.com/HDT3213/godis/src/config"
|
||||
"github.com/HDT3213/godis/src/config"
|
||||
"github.com/HDT3213/godis/src/datastruct/dict"
|
||||
List "github.com/HDT3213/godis/src/datastruct/list"
|
||||
"github.com/HDT3213/godis/src/datastruct/lock"
|
||||
@@ -29,18 +29,18 @@ func makeExpireCmd(key string, expireAt time.Time) *reply.MultiBulkReply {
|
||||
}
|
||||
|
||||
func makeAofCmd(cmd string, args [][]byte) *reply.MultiBulkReply {
|
||||
params := make([][]byte, len(args)+1)
|
||||
copy(params[1:], args)
|
||||
params[0] = []byte(cmd)
|
||||
return reply.MakeMultiBulkReply(params)
|
||||
params := make([][]byte, len(args)+1)
|
||||
copy(params[1:], args)
|
||||
params[0] = []byte(cmd)
|
||||
return reply.MakeMultiBulkReply(params)
|
||||
}
|
||||
|
||||
// send command to aof
|
||||
func (db *DB) AddAof(args *reply.MultiBulkReply) {
|
||||
// aofChan == nil when loadAof
|
||||
if config.Properties.AppendOnly && db.aofChan != nil {
|
||||
db.aofChan <- args
|
||||
}
|
||||
// aofChan == nil when loadAof
|
||||
if config.Properties.AppendOnly && db.aofChan != nil {
|
||||
db.aofChan <- args
|
||||
}
|
||||
}
|
||||
|
||||
// listen aof channel and write into file
|
||||
@@ -72,12 +72,12 @@ func trim(msg []byte) string {
|
||||
|
||||
// read aof file
|
||||
func (db *DB) loadAof(maxBytes int) {
|
||||
// delete aofChan to prevent write again
|
||||
aofChan := db.aofChan
|
||||
db.aofChan = nil
|
||||
defer func(aofChan chan *reply.MultiBulkReply) {
|
||||
db.aofChan = aofChan
|
||||
}(aofChan)
|
||||
// delete aofChan to prevent write again
|
||||
aofChan := db.aofChan
|
||||
db.aofChan = nil
|
||||
defer func(aofChan chan *reply.MultiBulkReply) {
|
||||
db.aofChan = aofChan
|
||||
}(aofChan)
|
||||
|
||||
file, err := os.Open(db.aofFilename)
|
||||
if err != nil {
|
||||
|
366
src/db/db.go
366
src/db/db.go
@@ -1,297 +1,297 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/HDT3213/godis/src/config"
|
||||
"github.com/HDT3213/godis/src/datastruct/dict"
|
||||
List "github.com/HDT3213/godis/src/datastruct/list"
|
||||
"github.com/HDT3213/godis/src/datastruct/lock"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/lib/logger"
|
||||
"github.com/HDT3213/godis/src/pubsub"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"fmt"
|
||||
"github.com/HDT3213/godis/src/config"
|
||||
"github.com/HDT3213/godis/src/datastruct/dict"
|
||||
List "github.com/HDT3213/godis/src/datastruct/list"
|
||||
"github.com/HDT3213/godis/src/datastruct/lock"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/lib/logger"
|
||||
"github.com/HDT3213/godis/src/pubsub"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DataEntity struct {
|
||||
Data interface{}
|
||||
Data interface{}
|
||||
}
|
||||
|
||||
const (
|
||||
dataDictSize = 1 << 16
|
||||
ttlDictSize = 1 << 10
|
||||
lockerSize = 128
|
||||
aofQueueSize = 1 << 16
|
||||
dataDictSize = 1 << 16
|
||||
ttlDictSize = 1 << 10
|
||||
lockerSize = 128
|
||||
aofQueueSize = 1 << 16
|
||||
)
|
||||
|
||||
// args don't include cmd line
|
||||
type CmdFunc func(db *DB, args [][]byte) redis.Reply
|
||||
|
||||
type DB struct {
|
||||
// key -> DataEntity
|
||||
Data dict.Dict
|
||||
// key -> expireTime (time.Time)
|
||||
TTLMap dict.Dict
|
||||
// channel -> list<*client>
|
||||
SubMap dict.Dict
|
||||
// key -> DataEntity
|
||||
Data dict.Dict
|
||||
// key -> expireTime (time.Time)
|
||||
TTLMap dict.Dict
|
||||
// channel -> list<*client>
|
||||
SubMap dict.Dict
|
||||
|
||||
// dict will ensure thread safety of its method
|
||||
// use this mutex for complicated command only, eg. rpush, incr ...
|
||||
Locker *lock.Locks
|
||||
// dict will ensure thread safety of its method
|
||||
// use this mutex for complicated command only, eg. rpush, incr ...
|
||||
Locker *lock.Locks
|
||||
|
||||
// TimerTask interval
|
||||
interval time.Duration
|
||||
// TimerTask interval
|
||||
interval time.Duration
|
||||
|
||||
stopWorld sync.WaitGroup
|
||||
stopWorld sync.WaitGroup
|
||||
|
||||
hub *pubsub.Hub
|
||||
hub *pubsub.Hub
|
||||
|
||||
// main goroutine send commands to aof goroutine through aofChan
|
||||
aofChan chan *reply.MultiBulkReply
|
||||
aofFile *os.File
|
||||
aofFilename string
|
||||
// main goroutine send commands to aof goroutine through aofChan
|
||||
aofChan chan *reply.MultiBulkReply
|
||||
aofFile *os.File
|
||||
aofFilename string
|
||||
|
||||
aofRewriteChan chan *reply.MultiBulkReply
|
||||
pausingAof sync.RWMutex
|
||||
aofRewriteChan chan *reply.MultiBulkReply
|
||||
pausingAof sync.RWMutex
|
||||
}
|
||||
|
||||
var router = MakeRouter()
|
||||
|
||||
func MakeDB() *DB {
|
||||
db := &DB{
|
||||
Data: dict.MakeConcurrent(dataDictSize),
|
||||
TTLMap: dict.MakeConcurrent(ttlDictSize),
|
||||
Locker: lock.Make(lockerSize),
|
||||
interval: 5 * time.Second,
|
||||
hub: pubsub.MakeHub(),
|
||||
}
|
||||
db := &DB{
|
||||
Data: dict.MakeConcurrent(dataDictSize),
|
||||
TTLMap: dict.MakeConcurrent(ttlDictSize),
|
||||
Locker: lock.Make(lockerSize),
|
||||
interval: 5 * time.Second,
|
||||
hub: pubsub.MakeHub(),
|
||||
}
|
||||
|
||||
// aof
|
||||
if config.Properties.AppendOnly {
|
||||
db.aofFilename = config.Properties.AppendFilename
|
||||
db.loadAof(0)
|
||||
aofFile, err := os.OpenFile(db.aofFilename, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
} else {
|
||||
db.aofFile = aofFile
|
||||
db.aofChan = make(chan *reply.MultiBulkReply, aofQueueSize)
|
||||
}
|
||||
go func() {
|
||||
db.handleAof()
|
||||
}()
|
||||
}
|
||||
// aof
|
||||
if config.Properties.AppendOnly {
|
||||
db.aofFilename = config.Properties.AppendFilename
|
||||
db.loadAof(0)
|
||||
aofFile, err := os.OpenFile(db.aofFilename, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
} else {
|
||||
db.aofFile = aofFile
|
||||
db.aofChan = make(chan *reply.MultiBulkReply, aofQueueSize)
|
||||
}
|
||||
go func() {
|
||||
db.handleAof()
|
||||
}()
|
||||
}
|
||||
|
||||
// start timer
|
||||
db.TimerTask()
|
||||
return db
|
||||
// start timer
|
||||
db.TimerTask()
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *DB) Close() {
|
||||
if db.aofFile != nil {
|
||||
err := db.aofFile.Close()
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
}
|
||||
}
|
||||
if db.aofFile != nil {
|
||||
err := db.aofFile.Close()
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) Exec(c redis.Connection, args [][]byte) (result redis.Reply) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack())))
|
||||
result = &reply.UnknownErrReply{}
|
||||
}
|
||||
}()
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack())))
|
||||
result = &reply.UnknownErrReply{}
|
||||
}
|
||||
}()
|
||||
|
||||
cmd := strings.ToLower(string(args[0]))
|
||||
cmd := strings.ToLower(string(args[0]))
|
||||
|
||||
// special commands
|
||||
if cmd == "subscribe" {
|
||||
if len(args) < 2 {
|
||||
return &reply.ArgNumErrReply{Cmd: "subscribe"}
|
||||
}
|
||||
return pubsub.Subscribe(db.hub, c, args[1:])
|
||||
} else if cmd == "publish" {
|
||||
return pubsub.Publish(db.hub, args[1:])
|
||||
} else if cmd == "unsubscribe" {
|
||||
return pubsub.UnSubscribe(db.hub, c, args[1:])
|
||||
} else if cmd == "bgrewriteaof" {
|
||||
// aof.go imports router.go, router.go cannot import BGRewriteAOF from aof.go
|
||||
reply := BGRewriteAOF(db, args[1:])
|
||||
return reply
|
||||
}
|
||||
// special commands
|
||||
if cmd == "subscribe" {
|
||||
if len(args) < 2 {
|
||||
return &reply.ArgNumErrReply{Cmd: "subscribe"}
|
||||
}
|
||||
return pubsub.Subscribe(db.hub, c, args[1:])
|
||||
} else if cmd == "publish" {
|
||||
return pubsub.Publish(db.hub, args[1:])
|
||||
} else if cmd == "unsubscribe" {
|
||||
return pubsub.UnSubscribe(db.hub, c, args[1:])
|
||||
} else if cmd == "bgrewriteaof" {
|
||||
// aof.go imports router.go, router.go cannot import BGRewriteAOF from aof.go
|
||||
reply := BGRewriteAOF(db, args[1:])
|
||||
return reply
|
||||
}
|
||||
|
||||
// normal commands
|
||||
cmdFunc, ok := router[cmd]
|
||||
if !ok {
|
||||
return reply.MakeErrReply("ERR unknown command '" + cmd + "'")
|
||||
}
|
||||
if len(args) > 1 {
|
||||
result = cmdFunc(db, args[1:])
|
||||
} else {
|
||||
result = cmdFunc(db, [][]byte{})
|
||||
}
|
||||
// normal commands
|
||||
cmdFunc, ok := router[cmd]
|
||||
if !ok {
|
||||
return reply.MakeErrReply("ERR unknown command '" + cmd + "'")
|
||||
}
|
||||
if len(args) > 1 {
|
||||
result = cmdFunc(db, args[1:])
|
||||
} else {
|
||||
result = cmdFunc(db, [][]byte{})
|
||||
}
|
||||
|
||||
// aof
|
||||
// aof
|
||||
|
||||
return
|
||||
return
|
||||
}
|
||||
|
||||
/* ---- Data Access ----- */
|
||||
|
||||
func (db *DB) Get(key string) (*DataEntity, bool) {
|
||||
db.stopWorld.Wait()
|
||||
db.stopWorld.Wait()
|
||||
|
||||
raw, ok := db.Data.Get(key)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if db.IsExpired(key) {
|
||||
return nil, false
|
||||
}
|
||||
entity, _ := raw.(*DataEntity)
|
||||
return entity, true
|
||||
raw, ok := db.Data.Get(key)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
if db.IsExpired(key) {
|
||||
return nil, false
|
||||
}
|
||||
entity, _ := raw.(*DataEntity)
|
||||
return entity, true
|
||||
}
|
||||
|
||||
func (db *DB) Put(key string, entity *DataEntity) int {
|
||||
db.stopWorld.Wait()
|
||||
return db.Data.Put(key, entity)
|
||||
db.stopWorld.Wait()
|
||||
return db.Data.Put(key, entity)
|
||||
}
|
||||
|
||||
func (db *DB) PutIfExists(key string, entity *DataEntity) int {
|
||||
db.stopWorld.Wait()
|
||||
return db.Data.PutIfExists(key, entity)
|
||||
db.stopWorld.Wait()
|
||||
return db.Data.PutIfExists(key, entity)
|
||||
}
|
||||
|
||||
func (db *DB) PutIfAbsent(key string, entity *DataEntity) int {
|
||||
db.stopWorld.Wait()
|
||||
return db.Data.PutIfAbsent(key, entity)
|
||||
db.stopWorld.Wait()
|
||||
return db.Data.PutIfAbsent(key, entity)
|
||||
}
|
||||
|
||||
func (db *DB) Remove(key string) {
|
||||
db.stopWorld.Wait()
|
||||
db.Data.Remove(key)
|
||||
db.TTLMap.Remove(key)
|
||||
db.stopWorld.Wait()
|
||||
db.Data.Remove(key)
|
||||
db.TTLMap.Remove(key)
|
||||
}
|
||||
|
||||
func (db *DB) Removes(keys ...string) (deleted int) {
|
||||
db.stopWorld.Wait()
|
||||
deleted = 0
|
||||
for _, key := range keys {
|
||||
_, exists := db.Data.Get(key)
|
||||
if exists {
|
||||
db.Data.Remove(key)
|
||||
db.TTLMap.Remove(key)
|
||||
deleted++
|
||||
}
|
||||
}
|
||||
return deleted
|
||||
db.stopWorld.Wait()
|
||||
deleted = 0
|
||||
for _, key := range keys {
|
||||
_, exists := db.Data.Get(key)
|
||||
if exists {
|
||||
db.Data.Remove(key)
|
||||
db.TTLMap.Remove(key)
|
||||
deleted++
|
||||
}
|
||||
}
|
||||
return deleted
|
||||
}
|
||||
|
||||
func (db *DB) Flush() {
|
||||
db.stopWorld.Add(1)
|
||||
defer db.stopWorld.Done()
|
||||
db.stopWorld.Add(1)
|
||||
defer db.stopWorld.Done()
|
||||
|
||||
db.Data = dict.MakeConcurrent(dataDictSize)
|
||||
db.TTLMap = dict.MakeConcurrent(ttlDictSize)
|
||||
db.Locker = lock.Make(lockerSize)
|
||||
db.Data = dict.MakeConcurrent(dataDictSize)
|
||||
db.TTLMap = dict.MakeConcurrent(ttlDictSize)
|
||||
db.Locker = lock.Make(lockerSize)
|
||||
|
||||
}
|
||||
|
||||
/* ---- Lock Function ----- */
|
||||
|
||||
func (db *DB) Lock(key string) {
|
||||
db.Locker.Lock(key)
|
||||
db.Locker.Lock(key)
|
||||
}
|
||||
|
||||
func (db *DB) RLock(key string) {
|
||||
db.Locker.RLock(key)
|
||||
db.Locker.RLock(key)
|
||||
}
|
||||
|
||||
func (db *DB) UnLock(key string) {
|
||||
db.Locker.UnLock(key)
|
||||
db.Locker.UnLock(key)
|
||||
}
|
||||
|
||||
func (db *DB) RUnLock(key string) {
|
||||
db.Locker.RUnLock(key)
|
||||
db.Locker.RUnLock(key)
|
||||
}
|
||||
|
||||
func (db *DB) Locks(keys ...string) {
|
||||
db.Locker.Locks(keys...)
|
||||
db.Locker.Locks(keys...)
|
||||
}
|
||||
|
||||
func (db *DB) RLocks(keys ...string) {
|
||||
db.Locker.RLocks(keys...)
|
||||
db.Locker.RLocks(keys...)
|
||||
}
|
||||
|
||||
func (db *DB) UnLocks(keys ...string) {
|
||||
db.Locker.UnLocks(keys...)
|
||||
db.Locker.UnLocks(keys...)
|
||||
}
|
||||
|
||||
func (db *DB) RUnLocks(keys ...string) {
|
||||
db.Locker.RUnLocks(keys...)
|
||||
db.Locker.RUnLocks(keys...)
|
||||
}
|
||||
|
||||
/* ---- TTL Functions ---- */
|
||||
|
||||
func (db *DB) Expire(key string, expireTime time.Time) {
|
||||
db.stopWorld.Wait()
|
||||
db.TTLMap.Put(key, expireTime)
|
||||
db.stopWorld.Wait()
|
||||
db.TTLMap.Put(key, expireTime)
|
||||
}
|
||||
|
||||
func (db *DB) Persist(key string) {
|
||||
db.stopWorld.Wait()
|
||||
db.TTLMap.Remove(key)
|
||||
db.stopWorld.Wait()
|
||||
db.TTLMap.Remove(key)
|
||||
}
|
||||
|
||||
func (db *DB) IsExpired(key string) bool {
|
||||
rawExpireTime, ok := db.TTLMap.Get(key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
expireTime, _ := rawExpireTime.(time.Time)
|
||||
expired := time.Now().After(expireTime)
|
||||
if expired {
|
||||
db.Remove(key)
|
||||
}
|
||||
return expired
|
||||
rawExpireTime, ok := db.TTLMap.Get(key)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
expireTime, _ := rawExpireTime.(time.Time)
|
||||
expired := time.Now().After(expireTime)
|
||||
if expired {
|
||||
db.Remove(key)
|
||||
}
|
||||
return expired
|
||||
}
|
||||
|
||||
func (db *DB) CleanExpired() {
|
||||
now := time.Now()
|
||||
toRemove := &List.LinkedList{}
|
||||
db.TTLMap.ForEach(func(key string, val interface{}) bool {
|
||||
expireTime, _ := val.(time.Time)
|
||||
if now.After(expireTime) {
|
||||
// expired
|
||||
db.Data.Remove(key)
|
||||
toRemove.Add(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
toRemove.ForEach(func(i int, val interface{}) bool {
|
||||
key, _ := val.(string)
|
||||
db.TTLMap.Remove(key)
|
||||
return true
|
||||
})
|
||||
now := time.Now()
|
||||
toRemove := &List.LinkedList{}
|
||||
db.TTLMap.ForEach(func(key string, val interface{}) bool {
|
||||
expireTime, _ := val.(time.Time)
|
||||
if now.After(expireTime) {
|
||||
// expired
|
||||
db.Data.Remove(key)
|
||||
toRemove.Add(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
toRemove.ForEach(func(i int, val interface{}) bool {
|
||||
key, _ := val.(string)
|
||||
db.TTLMap.Remove(key)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (db *DB) TimerTask() {
|
||||
ticker := time.NewTicker(db.interval)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
db.CleanExpired()
|
||||
}
|
||||
}()
|
||||
ticker := time.NewTicker(db.interval)
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
db.CleanExpired()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
/* ---- Subscribe Functions ---- */
|
||||
|
||||
func (db *DB) AfterClientClose(c redis.Connection) {
|
||||
pubsub.UnsubscribeAll(db.hub, c)
|
||||
pubsub.UnsubscribeAll(db.hub, c)
|
||||
}
|
||||
|
674
src/db/hash.go
674
src/db/hash.go
@@ -1,422 +1,422 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
Dict "github.com/HDT3213/godis/src/datastruct/dict"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"github.com/shopspring/decimal"
|
||||
"strconv"
|
||||
Dict "github.com/HDT3213/godis/src/datastruct/dict"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"github.com/shopspring/decimal"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func (db *DB) getAsDict(key string) (Dict.Dict, reply.ErrorReply) {
|
||||
entity, exists := db.Get(key)
|
||||
if !exists {
|
||||
return nil, nil
|
||||
}
|
||||
dict, ok := entity.Data.(Dict.Dict)
|
||||
if !ok {
|
||||
return nil, &reply.WrongTypeErrReply{}
|
||||
}
|
||||
return dict, nil
|
||||
entity, exists := db.Get(key)
|
||||
if !exists {
|
||||
return nil, nil
|
||||
}
|
||||
dict, ok := entity.Data.(Dict.Dict)
|
||||
if !ok {
|
||||
return nil, &reply.WrongTypeErrReply{}
|
||||
}
|
||||
return dict, nil
|
||||
}
|
||||
|
||||
func (db *DB) getOrInitDict(key string) (dict Dict.Dict, inited bool, errReply reply.ErrorReply) {
|
||||
dict, errReply = db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return nil, false, errReply
|
||||
}
|
||||
inited = false
|
||||
if dict == nil {
|
||||
dict = Dict.MakeSimple()
|
||||
db.Put(key, &DataEntity{
|
||||
Data: dict,
|
||||
})
|
||||
inited = true
|
||||
}
|
||||
return dict, inited, nil
|
||||
dict, errReply = db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return nil, false, errReply
|
||||
}
|
||||
inited = false
|
||||
if dict == nil {
|
||||
dict = Dict.MakeSimple()
|
||||
db.Put(key, &DataEntity{
|
||||
Data: dict,
|
||||
})
|
||||
inited = true
|
||||
}
|
||||
return dict, inited, nil
|
||||
}
|
||||
|
||||
func HSet(db *DB, args [][]byte) redis.Reply {
|
||||
// parse args
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hset' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
field := string(args[1])
|
||||
value := args[2]
|
||||
// parse args
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hset' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
field := string(args[1])
|
||||
value := args[2]
|
||||
|
||||
// lock
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
// lock
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
|
||||
// get or init entity
|
||||
dict, _, errReply := db.getOrInitDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
// get or init entity
|
||||
dict, _, errReply := db.getOrInitDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
|
||||
result := dict.Put(field, value)
|
||||
db.AddAof(makeAofCmd("hset", args))
|
||||
return reply.MakeIntReply(int64(result))
|
||||
result := dict.Put(field, value)
|
||||
db.AddAof(makeAofCmd("hset", args))
|
||||
return reply.MakeIntReply(int64(result))
|
||||
}
|
||||
|
||||
func HSetNX(db *DB, args [][]byte) redis.Reply {
|
||||
// parse args
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hsetnx' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
field := string(args[1])
|
||||
value := args[2]
|
||||
// parse args
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hsetnx' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
field := string(args[1])
|
||||
value := args[2]
|
||||
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
|
||||
dict, _, errReply := db.getOrInitDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
dict, _, errReply := db.getOrInitDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
|
||||
result := dict.PutIfAbsent(field, value)
|
||||
if result > 0 {
|
||||
db.AddAof(makeAofCmd("hsetnx", args))
|
||||
result := dict.PutIfAbsent(field, value)
|
||||
if result > 0 {
|
||||
db.AddAof(makeAofCmd("hsetnx", args))
|
||||
|
||||
}
|
||||
return reply.MakeIntReply(int64(result))
|
||||
}
|
||||
return reply.MakeIntReply(int64(result))
|
||||
}
|
||||
|
||||
func HGet(db *DB, args [][]byte) redis.Reply {
|
||||
// parse args
|
||||
if len(args) != 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hget' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
field := string(args[1])
|
||||
// parse args
|
||||
if len(args) != 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hget' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
field := string(args[1])
|
||||
|
||||
// get entity
|
||||
dict, errReply := db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if dict == nil {
|
||||
return &reply.NullBulkReply{}
|
||||
}
|
||||
// get entity
|
||||
dict, errReply := db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if dict == nil {
|
||||
return &reply.NullBulkReply{}
|
||||
}
|
||||
|
||||
raw, exists := dict.Get(field)
|
||||
if !exists {
|
||||
return &reply.NullBulkReply{}
|
||||
}
|
||||
value, _ := raw.([]byte)
|
||||
return reply.MakeBulkReply(value)
|
||||
raw, exists := dict.Get(field)
|
||||
if !exists {
|
||||
return &reply.NullBulkReply{}
|
||||
}
|
||||
value, _ := raw.([]byte)
|
||||
return reply.MakeBulkReply(value)
|
||||
}
|
||||
|
||||
func HExists(db *DB, args [][]byte) redis.Reply {
|
||||
// parse args
|
||||
if len(args) != 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hexists' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
field := string(args[1])
|
||||
// parse args
|
||||
if len(args) != 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hexists' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
field := string(args[1])
|
||||
|
||||
// get entity
|
||||
dict, errReply := db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if dict == nil {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
// get entity
|
||||
dict, errReply := db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if dict == nil {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
|
||||
_, exists := dict.Get(field)
|
||||
if exists {
|
||||
return reply.MakeIntReply(1)
|
||||
}
|
||||
return reply.MakeIntReply(0)
|
||||
_, exists := dict.Get(field)
|
||||
if exists {
|
||||
return reply.MakeIntReply(1)
|
||||
}
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
|
||||
func HDel(db *DB, args [][]byte) redis.Reply {
|
||||
// parse args
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hdel' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
fields := make([]string, len(args) - 1)
|
||||
fieldArgs := args[1:]
|
||||
for i, v := range fieldArgs {
|
||||
fields[i] = string(v)
|
||||
}
|
||||
// parse args
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hdel' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
fields := make([]string, len(args)-1)
|
||||
fieldArgs := args[1:]
|
||||
for i, v := range fieldArgs {
|
||||
fields[i] = string(v)
|
||||
}
|
||||
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
|
||||
// get entity
|
||||
dict, errReply := db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if dict == nil {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
// get entity
|
||||
dict, errReply := db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if dict == nil {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
|
||||
deleted := 0
|
||||
for _, field := range fields {
|
||||
result := dict.Remove(field)
|
||||
deleted += result
|
||||
}
|
||||
if dict.Len() == 0 {
|
||||
db.Remove(key)
|
||||
}
|
||||
if deleted > 0 {
|
||||
db.AddAof(makeAofCmd("hdel", args))
|
||||
}
|
||||
deleted := 0
|
||||
for _, field := range fields {
|
||||
result := dict.Remove(field)
|
||||
deleted += result
|
||||
}
|
||||
if dict.Len() == 0 {
|
||||
db.Remove(key)
|
||||
}
|
||||
if deleted > 0 {
|
||||
db.AddAof(makeAofCmd("hdel", args))
|
||||
}
|
||||
|
||||
return reply.MakeIntReply(int64(deleted))
|
||||
return reply.MakeIntReply(int64(deleted))
|
||||
}
|
||||
|
||||
func HLen(db *DB, args [][]byte) redis.Reply {
|
||||
// parse args
|
||||
if len(args) != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hlen' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
// parse args
|
||||
if len(args) != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hlen' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
|
||||
dict, errReply := db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if dict == nil {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
return reply.MakeIntReply(int64(dict.Len()))
|
||||
dict, errReply := db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if dict == nil {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
return reply.MakeIntReply(int64(dict.Len()))
|
||||
}
|
||||
|
||||
func HMSet(db *DB, args [][]byte) redis.Reply {
|
||||
// parse args
|
||||
if len(args) < 3 || len(args) % 2 != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hmset' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
size := (len(args) - 1) / 2
|
||||
fields := make([]string, size)
|
||||
values := make([][]byte, size)
|
||||
for i := 0; i < size; i++ {
|
||||
fields[i] = string(args[2 * i + 1])
|
||||
values[i] = args[2 * i + 2]
|
||||
}
|
||||
// parse args
|
||||
if len(args) < 3 || len(args)%2 != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hmset' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
size := (len(args) - 1) / 2
|
||||
fields := make([]string, size)
|
||||
values := make([][]byte, size)
|
||||
for i := 0; i < size; i++ {
|
||||
fields[i] = string(args[2*i+1])
|
||||
values[i] = args[2*i+2]
|
||||
}
|
||||
|
||||
// lock key
|
||||
db.Locker.Lock(key)
|
||||
defer db.Locker.UnLock(key)
|
||||
// lock key
|
||||
db.Locker.Lock(key)
|
||||
defer db.Locker.UnLock(key)
|
||||
|
||||
// get or init entity
|
||||
dict, _, errReply := db.getOrInitDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
// get or init entity
|
||||
dict, _, errReply := db.getOrInitDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
|
||||
// put data
|
||||
for i, field := range fields {
|
||||
value := values[i]
|
||||
dict.Put(field, value)
|
||||
}
|
||||
db.AddAof(makeAofCmd("hmset", args))
|
||||
return &reply.OkReply{}
|
||||
// put data
|
||||
for i, field := range fields {
|
||||
value := values[i]
|
||||
dict.Put(field, value)
|
||||
}
|
||||
db.AddAof(makeAofCmd("hmset", args))
|
||||
return &reply.OkReply{}
|
||||
}
|
||||
|
||||
func HMGet(db *DB, args [][]byte) redis.Reply {
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hmget' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
size := len(args) - 1
|
||||
fields := make([]string, size)
|
||||
for i := 0; i < size; i++ {
|
||||
fields[i] = string(args[i + 1])
|
||||
}
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hmget' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
size := len(args) - 1
|
||||
fields := make([]string, size)
|
||||
for i := 0; i < size; i++ {
|
||||
fields[i] = string(args[i+1])
|
||||
}
|
||||
|
||||
db.RLock(key)
|
||||
defer db.RUnLock(key)
|
||||
db.RLock(key)
|
||||
defer db.RUnLock(key)
|
||||
|
||||
// get entity
|
||||
result := make([][]byte, size)
|
||||
dict, errReply := db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if dict == nil {
|
||||
return reply.MakeMultiBulkReply(result)
|
||||
}
|
||||
// get entity
|
||||
result := make([][]byte, size)
|
||||
dict, errReply := db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if dict == nil {
|
||||
return reply.MakeMultiBulkReply(result)
|
||||
}
|
||||
|
||||
for i, field := range fields {
|
||||
value, ok := dict.Get(field)
|
||||
if !ok {
|
||||
result[i] = nil
|
||||
} else {
|
||||
bytes, _ := value.([]byte)
|
||||
result[i] = bytes
|
||||
}
|
||||
}
|
||||
return reply.MakeMultiBulkReply(result)
|
||||
for i, field := range fields {
|
||||
value, ok := dict.Get(field)
|
||||
if !ok {
|
||||
result[i] = nil
|
||||
} else {
|
||||
bytes, _ := value.([]byte)
|
||||
result[i] = bytes
|
||||
}
|
||||
}
|
||||
return reply.MakeMultiBulkReply(result)
|
||||
}
|
||||
|
||||
func HKeys(db *DB, args [][]byte) redis.Reply {
|
||||
if len(args) != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hkeys' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
if len(args) != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hkeys' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
|
||||
db.RLock(key)
|
||||
defer db.RUnLock(key)
|
||||
db.RLock(key)
|
||||
defer db.RUnLock(key)
|
||||
|
||||
dict, errReply := db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if dict == nil {
|
||||
return &reply.EmptyMultiBulkReply{}
|
||||
}
|
||||
dict, errReply := db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if dict == nil {
|
||||
return &reply.EmptyMultiBulkReply{}
|
||||
}
|
||||
|
||||
fields := make([][]byte, dict.Len())
|
||||
i := 0
|
||||
dict.ForEach(func(key string, val interface{})bool {
|
||||
fields[i] = []byte(key)
|
||||
i++
|
||||
return true
|
||||
})
|
||||
return reply.MakeMultiBulkReply(fields[:i])
|
||||
fields := make([][]byte, dict.Len())
|
||||
i := 0
|
||||
dict.ForEach(func(key string, val interface{}) bool {
|
||||
fields[i] = []byte(key)
|
||||
i++
|
||||
return true
|
||||
})
|
||||
return reply.MakeMultiBulkReply(fields[:i])
|
||||
}
|
||||
|
||||
func HVals(db *DB, args [][]byte) redis.Reply {
|
||||
if len(args) != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hvals' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
if len(args) != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hvals' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
|
||||
db.RLock(key)
|
||||
defer db.RUnLock(key)
|
||||
db.RLock(key)
|
||||
defer db.RUnLock(key)
|
||||
|
||||
// get entity
|
||||
dict, errReply := db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if dict == nil {
|
||||
return &reply.EmptyMultiBulkReply{}
|
||||
}
|
||||
// get entity
|
||||
dict, errReply := db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if dict == nil {
|
||||
return &reply.EmptyMultiBulkReply{}
|
||||
}
|
||||
|
||||
values := make([][]byte, dict.Len())
|
||||
i := 0
|
||||
dict.ForEach(func(key string, val interface{})bool {
|
||||
values[i], _ = val.([]byte)
|
||||
i++
|
||||
return true
|
||||
})
|
||||
return reply.MakeMultiBulkReply(values[:i])
|
||||
values := make([][]byte, dict.Len())
|
||||
i := 0
|
||||
dict.ForEach(func(key string, val interface{}) bool {
|
||||
values[i], _ = val.([]byte)
|
||||
i++
|
||||
return true
|
||||
})
|
||||
return reply.MakeMultiBulkReply(values[:i])
|
||||
}
|
||||
|
||||
func HGetAll(db *DB, args [][]byte) redis.Reply {
|
||||
if len(args) != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hgetAll' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
if len(args) != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hgetAll' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
|
||||
db.RLock(key)
|
||||
defer db.RUnLock(key)
|
||||
db.RLock(key)
|
||||
defer db.RUnLock(key)
|
||||
|
||||
// get entity
|
||||
dict, errReply := db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if dict == nil {
|
||||
return &reply.EmptyMultiBulkReply{}
|
||||
}
|
||||
// get entity
|
||||
dict, errReply := db.getAsDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if dict == nil {
|
||||
return &reply.EmptyMultiBulkReply{}
|
||||
}
|
||||
|
||||
size := dict.Len()
|
||||
result := make([][]byte, size * 2)
|
||||
i := 0
|
||||
dict.ForEach(func(key string, val interface{})bool {
|
||||
result[i] = []byte(key)
|
||||
i++
|
||||
result[i], _ = val.([]byte)
|
||||
i++
|
||||
return true
|
||||
})
|
||||
return reply.MakeMultiBulkReply(result[:i])
|
||||
size := dict.Len()
|
||||
result := make([][]byte, size*2)
|
||||
i := 0
|
||||
dict.ForEach(func(key string, val interface{}) bool {
|
||||
result[i] = []byte(key)
|
||||
i++
|
||||
result[i], _ = val.([]byte)
|
||||
i++
|
||||
return true
|
||||
})
|
||||
return reply.MakeMultiBulkReply(result[:i])
|
||||
}
|
||||
|
||||
func HIncrBy(db *DB, args [][]byte) redis.Reply {
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hincrby' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
field := string(args[1])
|
||||
rawDelta := string(args[2])
|
||||
delta, err := strconv.ParseInt(rawDelta, 10, 64)
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||
}
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hincrby' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
field := string(args[1])
|
||||
rawDelta := string(args[2])
|
||||
delta, err := strconv.ParseInt(rawDelta, 10, 64)
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||
}
|
||||
|
||||
db.Locker.Lock(key)
|
||||
defer db.Locker.UnLock(key)
|
||||
db.Locker.Lock(key)
|
||||
defer db.Locker.UnLock(key)
|
||||
|
||||
dict, _, errReply := db.getOrInitDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
dict, _, errReply := db.getOrInitDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
|
||||
value, exists := dict.Get(field)
|
||||
if !exists {
|
||||
dict.Put(field, args[2])
|
||||
db.AddAof(makeAofCmd("hincrby", args))
|
||||
return reply.MakeBulkReply(args[2])
|
||||
} else {
|
||||
val, err := strconv.ParseInt(string(value.([]byte)), 10, 64)
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR hash value is not an integer")
|
||||
}
|
||||
val += delta
|
||||
bytes := []byte(strconv.FormatInt(val, 10))
|
||||
dict.Put(field, bytes)
|
||||
db.AddAof(makeAofCmd("hincrby", args))
|
||||
return reply.MakeBulkReply(bytes)
|
||||
}
|
||||
value, exists := dict.Get(field)
|
||||
if !exists {
|
||||
dict.Put(field, args[2])
|
||||
db.AddAof(makeAofCmd("hincrby", args))
|
||||
return reply.MakeBulkReply(args[2])
|
||||
} else {
|
||||
val, err := strconv.ParseInt(string(value.([]byte)), 10, 64)
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR hash value is not an integer")
|
||||
}
|
||||
val += delta
|
||||
bytes := []byte(strconv.FormatInt(val, 10))
|
||||
dict.Put(field, bytes)
|
||||
db.AddAof(makeAofCmd("hincrby", args))
|
||||
return reply.MakeBulkReply(bytes)
|
||||
}
|
||||
}
|
||||
|
||||
func HIncrByFloat(db *DB, args [][]byte) redis.Reply {
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hincrbyfloat' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
field := string(args[1])
|
||||
rawDelta := string(args[2])
|
||||
delta, err := decimal.NewFromString(rawDelta)
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR value is not a valid float")
|
||||
}
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hincrbyfloat' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
field := string(args[1])
|
||||
rawDelta := string(args[2])
|
||||
delta, err := decimal.NewFromString(rawDelta)
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR value is not a valid float")
|
||||
}
|
||||
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
|
||||
// get or init entity
|
||||
dict, _, errReply := db.getOrInitDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
// get or init entity
|
||||
dict, _, errReply := db.getOrInitDict(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
|
||||
value, exists := dict.Get(field)
|
||||
if !exists {
|
||||
dict.Put(field, args[2])
|
||||
return reply.MakeBulkReply(args[2])
|
||||
} else {
|
||||
val, err := decimal.NewFromString(string(value.([]byte)))
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR hash value is not a float")
|
||||
}
|
||||
result := val.Add(delta)
|
||||
resultBytes:= []byte(result.String())
|
||||
dict.Put(field, resultBytes)
|
||||
db.AddAof(makeAofCmd("hincrbyfloat", args))
|
||||
return reply.MakeBulkReply(resultBytes)
|
||||
}
|
||||
}
|
||||
value, exists := dict.Get(field)
|
||||
if !exists {
|
||||
dict.Put(field, args[2])
|
||||
return reply.MakeBulkReply(args[2])
|
||||
} else {
|
||||
val, err := decimal.NewFromString(string(value.([]byte)))
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR hash value is not a float")
|
||||
}
|
||||
result := val.Add(delta)
|
||||
resultBytes := []byte(result.String())
|
||||
dict.Put(field, resultBytes)
|
||||
db.AddAof(makeAofCmd("hincrbyfloat", args))
|
||||
return reply.MakeBulkReply(resultBytes)
|
||||
}
|
||||
}
|
||||
|
704
src/db/list.go
704
src/db/list.go
@@ -1,439 +1,439 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
List "github.com/HDT3213/godis/src/datastruct/list"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"strconv"
|
||||
List "github.com/HDT3213/godis/src/datastruct/list"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func (db *DB) getAsList(key string)(*List.LinkedList, reply.ErrorReply) {
|
||||
entity, ok := db.Get(key)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
bytes, ok := entity.Data.(*List.LinkedList)
|
||||
if !ok {
|
||||
return nil, &reply.WrongTypeErrReply{}
|
||||
}
|
||||
return bytes, nil
|
||||
func (db *DB) getAsList(key string) (*List.LinkedList, reply.ErrorReply) {
|
||||
entity, ok := db.Get(key)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
bytes, ok := entity.Data.(*List.LinkedList)
|
||||
if !ok {
|
||||
return nil, &reply.WrongTypeErrReply{}
|
||||
}
|
||||
return bytes, nil
|
||||
}
|
||||
|
||||
func (db *DB) getOrInitList(key string)(list *List.LinkedList, inited bool, errReply reply.ErrorReply) {
|
||||
list, errReply = db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return nil, false, errReply
|
||||
}
|
||||
inited = false
|
||||
if list == nil {
|
||||
list = &List.LinkedList{}
|
||||
db.Put(key, &DataEntity{
|
||||
Data: list,
|
||||
})
|
||||
inited = true
|
||||
}
|
||||
return list, inited, nil
|
||||
func (db *DB) getOrInitList(key string) (list *List.LinkedList, inited bool, errReply reply.ErrorReply) {
|
||||
list, errReply = db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return nil, false, errReply
|
||||
}
|
||||
inited = false
|
||||
if list == nil {
|
||||
list = &List.LinkedList{}
|
||||
db.Put(key, &DataEntity{
|
||||
Data: list,
|
||||
})
|
||||
inited = true
|
||||
}
|
||||
return list, inited, nil
|
||||
}
|
||||
|
||||
func LIndex(db *DB, args [][]byte) redis.Reply {
|
||||
// parse args
|
||||
if len(args) != 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
index64, err := strconv.ParseInt(string(args[1]), 10, 64)
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||
}
|
||||
index := int(index64)
|
||||
// parse args
|
||||
if len(args) != 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
index64, err := strconv.ParseInt(string(args[1]), 10, 64)
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||
}
|
||||
index := int(index64)
|
||||
|
||||
// get entity
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return &reply.NullBulkReply{}
|
||||
}
|
||||
// get entity
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return &reply.NullBulkReply{}
|
||||
}
|
||||
|
||||
size := list.Len() // assert: size > 0
|
||||
if index < -1 * size {
|
||||
return &reply.NullBulkReply{}
|
||||
} else if index < 0 {
|
||||
index = size + index
|
||||
} else if index >= size {
|
||||
return &reply.NullBulkReply{}
|
||||
}
|
||||
size := list.Len() // assert: size > 0
|
||||
if index < -1*size {
|
||||
return &reply.NullBulkReply{}
|
||||
} else if index < 0 {
|
||||
index = size + index
|
||||
} else if index >= size {
|
||||
return &reply.NullBulkReply{}
|
||||
}
|
||||
|
||||
val, _ := list.Get(index).([]byte)
|
||||
return reply.MakeBulkReply(val)
|
||||
val, _ := list.Get(index).([]byte)
|
||||
return reply.MakeBulkReply(val)
|
||||
}
|
||||
|
||||
func LLen(db *DB, args [][]byte) redis.Reply {
|
||||
// parse args
|
||||
if len(args) != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'llen' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
// parse args
|
||||
if len(args) != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'llen' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
|
||||
size := int64(list.Len())
|
||||
return reply.MakeIntReply(size)
|
||||
size := int64(list.Len())
|
||||
return reply.MakeIntReply(size)
|
||||
}
|
||||
|
||||
func LPop(db *DB, args [][]byte) redis.Reply {
|
||||
// parse args
|
||||
if len(args) != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
// parse args
|
||||
if len(args) != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
|
||||
// lock
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
// lock
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
|
||||
// get data
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return &reply.NullBulkReply{}
|
||||
}
|
||||
// get data
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return &reply.NullBulkReply{}
|
||||
}
|
||||
|
||||
val, _ := list.Remove(0).([]byte)
|
||||
if list.Len() == 0 {
|
||||
db.Remove(key)
|
||||
}
|
||||
db.AddAof(makeAofCmd("lpop", args))
|
||||
return reply.MakeBulkReply(val)
|
||||
val, _ := list.Remove(0).([]byte)
|
||||
if list.Len() == 0 {
|
||||
db.Remove(key)
|
||||
}
|
||||
db.AddAof(makeAofCmd("lpop", args))
|
||||
return reply.MakeBulkReply(val)
|
||||
}
|
||||
|
||||
func LPush(db *DB, args [][]byte) redis.Reply {
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lpush' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
values := args[1:]
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lpush' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
values := args[1:]
|
||||
|
||||
// lock
|
||||
db.Locker.Lock(key)
|
||||
defer db.Locker.UnLock(key)
|
||||
// lock
|
||||
db.Locker.Lock(key)
|
||||
defer db.Locker.UnLock(key)
|
||||
|
||||
// get or init entity
|
||||
list, _, errReply := db.getOrInitList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
// get or init entity
|
||||
list, _, errReply := db.getOrInitList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
|
||||
// insert
|
||||
for _, value := range values {
|
||||
list.Insert(0, value)
|
||||
}
|
||||
// insert
|
||||
for _, value := range values {
|
||||
list.Insert(0, value)
|
||||
}
|
||||
|
||||
db.AddAof(makeAofCmd("lpush", args))
|
||||
return reply.MakeIntReply(int64(list.Len()))
|
||||
db.AddAof(makeAofCmd("lpush", args))
|
||||
return reply.MakeIntReply(int64(list.Len()))
|
||||
}
|
||||
|
||||
func LPushX(db *DB, args [][]byte) redis.Reply {
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lpushx' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
values := args[1:]
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lpushx' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
values := args[1:]
|
||||
|
||||
// lock
|
||||
db.Locker.Lock(key)
|
||||
defer db.Locker.UnLock(key)
|
||||
// lock
|
||||
db.Locker.Lock(key)
|
||||
defer db.Locker.UnLock(key)
|
||||
|
||||
// get or init entity
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
// get or init entity
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
|
||||
// insert
|
||||
for _, value := range values {
|
||||
list.Insert(0, value)
|
||||
}
|
||||
db.AddAof(makeAofCmd("lpushx", args))
|
||||
return reply.MakeIntReply(int64(list.Len()))
|
||||
// insert
|
||||
for _, value := range values {
|
||||
list.Insert(0, value)
|
||||
}
|
||||
db.AddAof(makeAofCmd("lpushx", args))
|
||||
return reply.MakeIntReply(int64(list.Len()))
|
||||
}
|
||||
|
||||
func LRange(db *DB, args [][]byte) redis.Reply {
|
||||
// parse args
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lrange' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
start64, err := strconv.ParseInt(string(args[1]), 10, 64)
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||
}
|
||||
start := int(start64)
|
||||
stop64, err := strconv.ParseInt(string(args[2]), 10, 64)
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||
}
|
||||
stop := int(stop64)
|
||||
// parse args
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lrange' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
start64, err := strconv.ParseInt(string(args[1]), 10, 64)
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||
}
|
||||
start := int(start64)
|
||||
stop64, err := strconv.ParseInt(string(args[2]), 10, 64)
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||
}
|
||||
stop := int(stop64)
|
||||
|
||||
// lock key
|
||||
db.RLock(key)
|
||||
defer db.RUnLock(key)
|
||||
// lock key
|
||||
db.RLock(key)
|
||||
defer db.RUnLock(key)
|
||||
|
||||
// get data
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return &reply.EmptyMultiBulkReply{}
|
||||
}
|
||||
// get data
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return &reply.EmptyMultiBulkReply{}
|
||||
}
|
||||
|
||||
// compute index
|
||||
size := list.Len() // assert: size > 0
|
||||
if start < -1 * size {
|
||||
start = 0
|
||||
} else if start < 0 {
|
||||
start = size + start
|
||||
} else if start >= size {
|
||||
return &reply.EmptyMultiBulkReply{}
|
||||
}
|
||||
if stop < -1 * size {
|
||||
stop = 0
|
||||
} else if stop < 0 {
|
||||
stop = size + stop + 1
|
||||
} else if stop < size {
|
||||
stop = stop + 1
|
||||
} else {
|
||||
stop = size
|
||||
}
|
||||
if stop < start {
|
||||
stop = start
|
||||
}
|
||||
// compute index
|
||||
size := list.Len() // assert: size > 0
|
||||
if start < -1*size {
|
||||
start = 0
|
||||
} else if start < 0 {
|
||||
start = size + start
|
||||
} else if start >= size {
|
||||
return &reply.EmptyMultiBulkReply{}
|
||||
}
|
||||
if stop < -1*size {
|
||||
stop = 0
|
||||
} else if stop < 0 {
|
||||
stop = size + stop + 1
|
||||
} else if stop < size {
|
||||
stop = stop + 1
|
||||
} else {
|
||||
stop = size
|
||||
}
|
||||
if stop < start {
|
||||
stop = start
|
||||
}
|
||||
|
||||
// assert: start in [0, size - 1], stop in [start, size]
|
||||
slice := list.Range(start, stop)
|
||||
result := make([][]byte, len(slice))
|
||||
for i, raw := range slice {
|
||||
bytes, _ := raw.([]byte)
|
||||
result[i] = bytes
|
||||
}
|
||||
return reply.MakeMultiBulkReply(result)
|
||||
// assert: start in [0, size - 1], stop in [start, size]
|
||||
slice := list.Range(start, stop)
|
||||
result := make([][]byte, len(slice))
|
||||
for i, raw := range slice {
|
||||
bytes, _ := raw.([]byte)
|
||||
result[i] = bytes
|
||||
}
|
||||
return reply.MakeMultiBulkReply(result)
|
||||
}
|
||||
|
||||
func LRem(db *DB, args [][]byte) redis.Reply {
|
||||
// parse args
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lrem' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
count64, err := strconv.ParseInt(string(args[1]), 10, 64)
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||
}
|
||||
count := int(count64)
|
||||
value := args[2]
|
||||
// parse args
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lrem' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
count64, err := strconv.ParseInt(string(args[1]), 10, 64)
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||
}
|
||||
count := int(count64)
|
||||
value := args[2]
|
||||
|
||||
// lock
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
// lock
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
|
||||
// get data entity
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
// get data entity
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
|
||||
var removed int
|
||||
if count == 0 {
|
||||
removed = list.RemoveAllByVal(value)
|
||||
} else if count > 0 {
|
||||
removed = list.RemoveByVal(value, count)
|
||||
} else {
|
||||
removed = list.ReverseRemoveByVal(value, -count)
|
||||
}
|
||||
var removed int
|
||||
if count == 0 {
|
||||
removed = list.RemoveAllByVal(value)
|
||||
} else if count > 0 {
|
||||
removed = list.RemoveByVal(value, count)
|
||||
} else {
|
||||
removed = list.ReverseRemoveByVal(value, -count)
|
||||
}
|
||||
|
||||
if list.Len() == 0 {
|
||||
db.Remove(key)
|
||||
}
|
||||
if removed > 0 {
|
||||
db.AddAof(makeAofCmd("lrem", args))
|
||||
}
|
||||
if list.Len() == 0 {
|
||||
db.Remove(key)
|
||||
}
|
||||
if removed > 0 {
|
||||
db.AddAof(makeAofCmd("lrem", args))
|
||||
}
|
||||
|
||||
return reply.MakeIntReply(int64(removed))
|
||||
return reply.MakeIntReply(int64(removed))
|
||||
}
|
||||
|
||||
func LSet(db *DB, args [][]byte) redis.Reply {
|
||||
// parse args
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lset' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
index64, err := strconv.ParseInt(string(args[1]), 10, 64)
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||
}
|
||||
index := int(index64)
|
||||
value := args[2]
|
||||
// parse args
|
||||
if len(args) != 3 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lset' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
index64, err := strconv.ParseInt(string(args[1]), 10, 64)
|
||||
if err != nil {
|
||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||
}
|
||||
index := int(index64)
|
||||
value := args[2]
|
||||
|
||||
// lock
|
||||
db.Locker.Lock(key)
|
||||
defer db.Locker.UnLock(key)
|
||||
// lock
|
||||
db.Locker.Lock(key)
|
||||
defer db.Locker.UnLock(key)
|
||||
|
||||
// get data
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return reply.MakeErrReply("ERR no such key")
|
||||
}
|
||||
// get data
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return reply.MakeErrReply("ERR no such key")
|
||||
}
|
||||
|
||||
size := list.Len() // assert: size > 0
|
||||
if index < -1 * size {
|
||||
return reply.MakeErrReply("ERR index out of range")
|
||||
} else if index < 0 {
|
||||
index = size + index
|
||||
} else if index >= size {
|
||||
return reply.MakeErrReply("ERR index out of range")
|
||||
}
|
||||
size := list.Len() // assert: size > 0
|
||||
if index < -1*size {
|
||||
return reply.MakeErrReply("ERR index out of range")
|
||||
} else if index < 0 {
|
||||
index = size + index
|
||||
} else if index >= size {
|
||||
return reply.MakeErrReply("ERR index out of range")
|
||||
}
|
||||
|
||||
list.Set(index, value)
|
||||
db.AddAof(makeAofCmd("lset", args))
|
||||
return &reply.OkReply{}
|
||||
list.Set(index, value)
|
||||
db.AddAof(makeAofCmd("lset", args))
|
||||
return &reply.OkReply{}
|
||||
}
|
||||
|
||||
func RPop(db *DB, args [][]byte) redis.Reply {
|
||||
// parse args
|
||||
if len(args) != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rpop' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
// parse args
|
||||
if len(args) != 1 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rpop' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
|
||||
// lock
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
// lock
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
|
||||
// get data
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return &reply.NullBulkReply{}
|
||||
}
|
||||
// get data
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return &reply.NullBulkReply{}
|
||||
}
|
||||
|
||||
val, _ := list.RemoveLast().([]byte)
|
||||
if list.Len() == 0 {
|
||||
db.Remove(key)
|
||||
}
|
||||
db.AddAof(makeAofCmd("rpop", args))
|
||||
return reply.MakeBulkReply(val)
|
||||
val, _ := list.RemoveLast().([]byte)
|
||||
if list.Len() == 0 {
|
||||
db.Remove(key)
|
||||
}
|
||||
db.AddAof(makeAofCmd("rpop", args))
|
||||
return reply.MakeBulkReply(val)
|
||||
}
|
||||
|
||||
func RPopLPush(db *DB, args [][]byte) redis.Reply {
|
||||
if len(args) != 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rpoplpush' command")
|
||||
}
|
||||
sourceKey := string(args[0])
|
||||
destKey := string(args[1])
|
||||
if len(args) != 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rpoplpush' command")
|
||||
}
|
||||
sourceKey := string(args[0])
|
||||
destKey := string(args[1])
|
||||
|
||||
// lock
|
||||
db.Locks(sourceKey, destKey)
|
||||
defer db.UnLocks(sourceKey, destKey)
|
||||
// lock
|
||||
db.Locks(sourceKey, destKey)
|
||||
defer db.UnLocks(sourceKey, destKey)
|
||||
|
||||
// get source entity
|
||||
sourceList, errReply := db.getAsList(sourceKey)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if sourceList == nil {
|
||||
return &reply.NullBulkReply{}
|
||||
}
|
||||
// get source entity
|
||||
sourceList, errReply := db.getAsList(sourceKey)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if sourceList == nil {
|
||||
return &reply.NullBulkReply{}
|
||||
}
|
||||
|
||||
// get dest entity
|
||||
destList, _, errReply := db.getOrInitList(destKey)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
// get dest entity
|
||||
destList, _, errReply := db.getOrInitList(destKey)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
|
||||
// pop and push
|
||||
val, _ := sourceList.RemoveLast().([]byte)
|
||||
destList.Insert(0, val)
|
||||
// pop and push
|
||||
val, _ := sourceList.RemoveLast().([]byte)
|
||||
destList.Insert(0, val)
|
||||
|
||||
if sourceList.Len() == 0 {
|
||||
db.Remove(sourceKey)
|
||||
}
|
||||
if sourceList.Len() == 0 {
|
||||
db.Remove(sourceKey)
|
||||
}
|
||||
|
||||
db.AddAof(makeAofCmd("rpoplpush", args))
|
||||
return reply.MakeBulkReply(val)
|
||||
db.AddAof(makeAofCmd("rpoplpush", args))
|
||||
return reply.MakeBulkReply(val)
|
||||
}
|
||||
|
||||
func RPush(db *DB, args [][]byte) redis.Reply {
|
||||
// parse args
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
values := args[1:]
|
||||
// parse args
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
values := args[1:]
|
||||
|
||||
// lock
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
// lock
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
|
||||
// get or init entity
|
||||
list, _, errReply := db.getOrInitList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
// get or init entity
|
||||
list, _, errReply := db.getOrInitList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
|
||||
// put list
|
||||
for _, value := range values {
|
||||
list.Add(value)
|
||||
}
|
||||
db.AddAof(makeAofCmd("rpush", args))
|
||||
return reply.MakeIntReply(int64(list.Len()))
|
||||
// put list
|
||||
for _, value := range values {
|
||||
list.Add(value)
|
||||
}
|
||||
db.AddAof(makeAofCmd("rpush", args))
|
||||
return reply.MakeIntReply(int64(list.Len()))
|
||||
}
|
||||
|
||||
func RPushX(db *DB, args [][]byte) redis.Reply {
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
values := args[1:]
|
||||
if len(args) < 2 {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command")
|
||||
}
|
||||
key := string(args[0])
|
||||
values := args[1:]
|
||||
|
||||
// lock
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
// lock
|
||||
db.Lock(key)
|
||||
defer db.UnLock(key)
|
||||
|
||||
// get or init entity
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
// get or init entity
|
||||
list, errReply := db.getAsList(key)
|
||||
if errReply != nil {
|
||||
return errReply
|
||||
}
|
||||
if list == nil {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
|
||||
// put list
|
||||
for _, value := range values {
|
||||
list.Add(value)
|
||||
}
|
||||
db.AddAof(makeAofCmd("rpushx", args))
|
||||
// put list
|
||||
for _, value := range values {
|
||||
list.Add(value)
|
||||
}
|
||||
db.AddAof(makeAofCmd("rpushx", args))
|
||||
|
||||
return reply.MakeIntReply(int64(list.Len()))
|
||||
}
|
||||
return reply.MakeIntReply(int64(list.Len()))
|
||||
}
|
||||
|
180
src/db/router.go
180
src/db/router.go
@@ -1,102 +1,102 @@
|
||||
package db
|
||||
|
||||
func MakeRouter()map[string]CmdFunc {
|
||||
routerMap := make(map[string]CmdFunc)
|
||||
routerMap["ping"] = Ping
|
||||
func MakeRouter() map[string]CmdFunc {
|
||||
routerMap := make(map[string]CmdFunc)
|
||||
routerMap["ping"] = Ping
|
||||
|
||||
routerMap["del"] = Del
|
||||
routerMap["expire"] = Expire
|
||||
routerMap["expireat"] = ExpireAt
|
||||
routerMap["pexpire"] = PExpire
|
||||
routerMap["pexpireat"] = PExpireAt
|
||||
routerMap["ttl"] = TTL
|
||||
routerMap["pttl"] = PTTL
|
||||
routerMap["persist"] = Persist
|
||||
routerMap["exists"] = Exists
|
||||
routerMap["type"] = Type
|
||||
routerMap["rename"] = Rename
|
||||
routerMap["renamenx"] = RenameNx
|
||||
routerMap["del"] = Del
|
||||
routerMap["expire"] = Expire
|
||||
routerMap["expireat"] = ExpireAt
|
||||
routerMap["pexpire"] = PExpire
|
||||
routerMap["pexpireat"] = PExpireAt
|
||||
routerMap["ttl"] = TTL
|
||||
routerMap["pttl"] = PTTL
|
||||
routerMap["persist"] = Persist
|
||||
routerMap["exists"] = Exists
|
||||
routerMap["type"] = Type
|
||||
routerMap["rename"] = Rename
|
||||
routerMap["renamenx"] = RenameNx
|
||||
|
||||
routerMap["set"] = Set
|
||||
routerMap["setnx"] = SetNX
|
||||
routerMap["setex"] = SetEX
|
||||
routerMap["psetex"] = PSetEX
|
||||
routerMap["mset"] = MSet
|
||||
routerMap["mget"] = MGet
|
||||
routerMap["msetnx"] = MSetNX
|
||||
routerMap["get"] = Get
|
||||
routerMap["getset"] = GetSet
|
||||
routerMap["incr"] = Incr
|
||||
routerMap["incrby"] = IncrBy
|
||||
routerMap["incrbyfloat"] = IncrByFloat
|
||||
routerMap["decr"] = Decr
|
||||
routerMap["decrby"] = DecrBy
|
||||
routerMap["set"] = Set
|
||||
routerMap["setnx"] = SetNX
|
||||
routerMap["setex"] = SetEX
|
||||
routerMap["psetex"] = PSetEX
|
||||
routerMap["mset"] = MSet
|
||||
routerMap["mget"] = MGet
|
||||
routerMap["msetnx"] = MSetNX
|
||||
routerMap["get"] = Get
|
||||
routerMap["getset"] = GetSet
|
||||
routerMap["incr"] = Incr
|
||||
routerMap["incrby"] = IncrBy
|
||||
routerMap["incrbyfloat"] = IncrByFloat
|
||||
routerMap["decr"] = Decr
|
||||
routerMap["decrby"] = DecrBy
|
||||
|
||||
routerMap["lpush"] = LPush
|
||||
routerMap["lpushx"] = LPushX
|
||||
routerMap["rpush"] = RPush
|
||||
routerMap["rpushx"] = RPushX
|
||||
routerMap["lpop"] = LPop
|
||||
routerMap["rpop"] = RPop
|
||||
routerMap["rpoplpush"] = RPopLPush
|
||||
routerMap["lrem"] = LRem
|
||||
routerMap["llen"] = LLen
|
||||
routerMap["lindex"] = LIndex
|
||||
routerMap["lset"] = LSet
|
||||
routerMap["lrange"] = LRange
|
||||
routerMap["lpush"] = LPush
|
||||
routerMap["lpushx"] = LPushX
|
||||
routerMap["rpush"] = RPush
|
||||
routerMap["rpushx"] = RPushX
|
||||
routerMap["lpop"] = LPop
|
||||
routerMap["rpop"] = RPop
|
||||
routerMap["rpoplpush"] = RPopLPush
|
||||
routerMap["lrem"] = LRem
|
||||
routerMap["llen"] = LLen
|
||||
routerMap["lindex"] = LIndex
|
||||
routerMap["lset"] = LSet
|
||||
routerMap["lrange"] = LRange
|
||||
|
||||
routerMap["hset"] = HSet
|
||||
routerMap["hsetnx"] = HSetNX
|
||||
routerMap["hget"] = HGet
|
||||
routerMap["hexists"] = HExists
|
||||
routerMap["hdel"] = HDel
|
||||
routerMap["hlen"] = HLen
|
||||
routerMap["hmget"] = HMGet
|
||||
routerMap["hmset"] = HMSet
|
||||
routerMap["hkeys"] = HKeys
|
||||
routerMap["hvals"] = HVals
|
||||
routerMap["hgetall"] = HGetAll
|
||||
routerMap["hincrby"] = HIncrBy
|
||||
routerMap["hincrbyfloat"] = HIncrByFloat
|
||||
routerMap["hset"] = HSet
|
||||
routerMap["hsetnx"] = HSetNX
|
||||
routerMap["hget"] = HGet
|
||||
routerMap["hexists"] = HExists
|
||||
routerMap["hdel"] = HDel
|
||||
routerMap["hlen"] = HLen
|
||||
routerMap["hmget"] = HMGet
|
||||
routerMap["hmset"] = HMSet
|
||||
routerMap["hkeys"] = HKeys
|
||||
routerMap["hvals"] = HVals
|
||||
routerMap["hgetall"] = HGetAll
|
||||
routerMap["hincrby"] = HIncrBy
|
||||
routerMap["hincrbyfloat"] = HIncrByFloat
|
||||
|
||||
routerMap["sadd"] = SAdd
|
||||
routerMap["sismember"] = SIsMember
|
||||
routerMap["srem"] = SRem
|
||||
routerMap["scard"] = SCard
|
||||
routerMap["smembers"] = SMembers
|
||||
routerMap["sinter"] = SInter
|
||||
routerMap["sinterstore"] = SInterStore
|
||||
routerMap["sunion"] = SUnion
|
||||
routerMap["sunionstore"] = SUnionStore
|
||||
routerMap["sdiff"] = SDiff
|
||||
routerMap["sdiffstore"] = SDiffStore
|
||||
routerMap["srandmember"] = SRandMember
|
||||
routerMap["sadd"] = SAdd
|
||||
routerMap["sismember"] = SIsMember
|
||||
routerMap["srem"] = SRem
|
||||
routerMap["scard"] = SCard
|
||||
routerMap["smembers"] = SMembers
|
||||
routerMap["sinter"] = SInter
|
||||
routerMap["sinterstore"] = SInterStore
|
||||
routerMap["sunion"] = SUnion
|
||||
routerMap["sunionstore"] = SUnionStore
|
||||
routerMap["sdiff"] = SDiff
|
||||
routerMap["sdiffstore"] = SDiffStore
|
||||
routerMap["srandmember"] = SRandMember
|
||||
|
||||
routerMap["zadd"] = ZAdd
|
||||
routerMap["zscore"] = ZScore
|
||||
routerMap["zincrby"] = ZIncrBy
|
||||
routerMap["zrank"] = ZRank
|
||||
routerMap["zcount"] = ZCount
|
||||
routerMap["zrevrank"] = ZRevRank
|
||||
routerMap["zcard"] = ZCard
|
||||
routerMap["zrange"] = ZRange
|
||||
routerMap["zrevrange"] = ZRevRange
|
||||
routerMap["zrangebyscore"] = ZRangeByScore
|
||||
routerMap["zrevrangebyscore"] = ZRevRangeByScore
|
||||
routerMap["zrem"] = ZRem
|
||||
routerMap["zremrangebyscore"] = ZRemRangeByScore
|
||||
routerMap["zremrangebyrank"] = ZRemRangeByRank
|
||||
routerMap["zadd"] = ZAdd
|
||||
routerMap["zscore"] = ZScore
|
||||
routerMap["zincrby"] = ZIncrBy
|
||||
routerMap["zrank"] = ZRank
|
||||
routerMap["zcount"] = ZCount
|
||||
routerMap["zrevrank"] = ZRevRank
|
||||
routerMap["zcard"] = ZCard
|
||||
routerMap["zrange"] = ZRange
|
||||
routerMap["zrevrange"] = ZRevRange
|
||||
routerMap["zrangebyscore"] = ZRangeByScore
|
||||
routerMap["zrevrangebyscore"] = ZRevRangeByScore
|
||||
routerMap["zrem"] = ZRem
|
||||
routerMap["zremrangebyscore"] = ZRemRangeByScore
|
||||
routerMap["zremrangebyrank"] = ZRemRangeByRank
|
||||
|
||||
routerMap["geoadd"] = GeoAdd
|
||||
routerMap["geopos"] = GeoPos
|
||||
routerMap["geodist"] = GeoDist
|
||||
routerMap["geohash"] = GeoHash
|
||||
routerMap["georadius"] = GeoRadius
|
||||
routerMap["georadiusbymember"] = GeoRadiusByMember
|
||||
routerMap["geoadd"] = GeoAdd
|
||||
routerMap["geopos"] = GeoPos
|
||||
routerMap["geodist"] = GeoDist
|
||||
routerMap["geohash"] = GeoHash
|
||||
routerMap["georadius"] = GeoRadius
|
||||
routerMap["georadiusbymember"] = GeoRadiusByMember
|
||||
|
||||
routerMap["flushdb"] = FlushDB
|
||||
routerMap["flushall"] = FlushAll
|
||||
routerMap["keys"] = Keys
|
||||
routerMap["flushdb"] = FlushDB
|
||||
routerMap["flushall"] = FlushAll
|
||||
routerMap["keys"] = Keys
|
||||
|
||||
return routerMap
|
||||
return routerMap
|
||||
}
|
||||
|
@@ -1,16 +1,16 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
)
|
||||
|
||||
func Ping(db *DB, args [][]byte) redis.Reply {
|
||||
if len(args) == 0 {
|
||||
return &reply.PongReply{}
|
||||
} else if len(args) == 1 {
|
||||
return reply.MakeStatusReply("\"" + string(args[0]) + "\"")
|
||||
} else {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command")
|
||||
}
|
||||
}
|
||||
if len(args) == 0 {
|
||||
return &reply.PongReply{}
|
||||
} else if len(args) == 1 {
|
||||
return reply.MakeStatusReply("\"" + string(args[0]) + "\"")
|
||||
} else {
|
||||
return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command")
|
||||
}
|
||||
}
|
||||
|
1010
src/db/sortedset.go
1010
src/db/sortedset.go
File diff suppressed because it is too large
Load Diff
@@ -3,7 +3,7 @@ package db
|
||||
import "github.com/HDT3213/godis/src/interface/redis"
|
||||
|
||||
type DB interface {
|
||||
Exec(client redis.Connection, args [][]byte) redis.Reply
|
||||
AfterClientClose(c redis.Connection)
|
||||
Close()
|
||||
Exec(client redis.Connection, args [][]byte) redis.Reply
|
||||
AfterClientClose(c redis.Connection)
|
||||
Close()
|
||||
}
|
||||
|
@@ -1,11 +1,11 @@
|
||||
package redis
|
||||
|
||||
type Connection interface {
|
||||
Write([]byte) error
|
||||
Write([]byte) error
|
||||
|
||||
// client should keep its subscribing channels
|
||||
SubsChannel(channel string)
|
||||
UnSubsChannel(channel string)
|
||||
SubsCount()int
|
||||
GetChannels()[]string
|
||||
// client should keep its subscribing channels
|
||||
SubsChannel(channel string)
|
||||
UnSubsChannel(channel string)
|
||||
SubsCount() int
|
||||
GetChannels() []string
|
||||
}
|
||||
|
@@ -1,5 +1,5 @@
|
||||
package redis
|
||||
|
||||
type Reply interface {
|
||||
ToBytes()[]byte
|
||||
ToBytes() []byte
|
||||
}
|
||||
|
@@ -1,13 +1,13 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"context"
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
type HandleFunc func(ctx context.Context, conn net.Conn)
|
||||
|
||||
type Handler interface {
|
||||
Handle(ctx context.Context, conn net.Conn)
|
||||
Close()error
|
||||
}
|
||||
Handle(ctx context.Context, conn net.Conn)
|
||||
Close() error
|
||||
}
|
||||
|
@@ -1,80 +1,80 @@
|
||||
package consistenthash
|
||||
|
||||
import (
|
||||
"hash/crc32"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"hash/crc32"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type HashFunc func(data []byte) uint32
|
||||
|
||||
type Map struct {
|
||||
hashFunc HashFunc
|
||||
replicas int
|
||||
keys []int // sorted
|
||||
hashMap map[int]string
|
||||
hashFunc HashFunc
|
||||
replicas int
|
||||
keys []int // sorted
|
||||
hashMap map[int]string
|
||||
}
|
||||
|
||||
func New(replicas int, fn HashFunc) *Map {
|
||||
m := &Map{
|
||||
replicas: replicas,
|
||||
hashFunc: fn,
|
||||
hashMap: make(map[int]string),
|
||||
}
|
||||
if m.hashFunc == nil {
|
||||
m.hashFunc = crc32.ChecksumIEEE
|
||||
}
|
||||
return m
|
||||
m := &Map{
|
||||
replicas: replicas,
|
||||
hashFunc: fn,
|
||||
hashMap: make(map[int]string),
|
||||
}
|
||||
if m.hashFunc == nil {
|
||||
m.hashFunc = crc32.ChecksumIEEE
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *Map) IsEmpty() bool {
|
||||
return len(m.keys) == 0
|
||||
return len(m.keys) == 0
|
||||
}
|
||||
|
||||
func (m *Map) Add(keys ...string) {
|
||||
for _, key := range keys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
for i := 0; i < m.replicas; i++ {
|
||||
hash := int(m.hashFunc([]byte(strconv.Itoa(i) + key)))
|
||||
m.keys = append(m.keys, hash)
|
||||
m.hashMap[hash] = key
|
||||
}
|
||||
}
|
||||
sort.Ints(m.keys)
|
||||
for _, key := range keys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
for i := 0; i < m.replicas; i++ {
|
||||
hash := int(m.hashFunc([]byte(strconv.Itoa(i) + key)))
|
||||
m.keys = append(m.keys, hash)
|
||||
m.hashMap[hash] = key
|
||||
}
|
||||
}
|
||||
sort.Ints(m.keys)
|
||||
}
|
||||
|
||||
// support hash tag
|
||||
func getPartitionKey(key string) string {
|
||||
beg := strings.Index(key, "{")
|
||||
if beg == -1 {
|
||||
return key
|
||||
}
|
||||
end := strings.Index(key, "}")
|
||||
if end == -1 || end == beg+1 {
|
||||
return key
|
||||
}
|
||||
return key[beg+1 : end]
|
||||
beg := strings.Index(key, "{")
|
||||
if beg == -1 {
|
||||
return key
|
||||
}
|
||||
end := strings.Index(key, "}")
|
||||
if end == -1 || end == beg+1 {
|
||||
return key
|
||||
}
|
||||
return key[beg+1 : end]
|
||||
}
|
||||
|
||||
// Get gets the closest item in the hash to the provided key.
|
||||
func (m *Map) Get(key string) string {
|
||||
if m.IsEmpty() {
|
||||
return ""
|
||||
}
|
||||
if m.IsEmpty() {
|
||||
return ""
|
||||
}
|
||||
|
||||
partitionKey := getPartitionKey(key)
|
||||
hash := int(m.hashFunc([]byte(partitionKey)))
|
||||
partitionKey := getPartitionKey(key)
|
||||
hash := int(m.hashFunc([]byte(partitionKey)))
|
||||
|
||||
// Binary search for appropriate replica.
|
||||
idx := sort.Search(len(m.keys), func(i int) bool { return m.keys[i] >= hash })
|
||||
// Binary search for appropriate replica.
|
||||
idx := sort.Search(len(m.keys), func(i int) bool { return m.keys[i] >= hash })
|
||||
|
||||
// Means we have cycled back to the first replica.
|
||||
if idx == len(m.keys) {
|
||||
idx = 0
|
||||
}
|
||||
// Means we have cycled back to the first replica.
|
||||
if idx == len(m.keys) {
|
||||
idx = 0
|
||||
}
|
||||
|
||||
return m.hashMap[m.keys[idx]]
|
||||
return m.hashMap[m.keys[idx]]
|
||||
}
|
||||
|
@@ -1,78 +1,78 @@
|
||||
package files
|
||||
|
||||
import (
|
||||
"mime/multipart"
|
||||
"io/ioutil"
|
||||
"path"
|
||||
"os"
|
||||
"fmt"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path"
|
||||
)
|
||||
|
||||
func GetSize(f multipart.File) (int, error) {
|
||||
content, err := ioutil.ReadAll(f)
|
||||
content, err := ioutil.ReadAll(f)
|
||||
|
||||
return len(content), err
|
||||
return len(content), err
|
||||
}
|
||||
|
||||
func GetExt(fileName string) string {
|
||||
return path.Ext(fileName)
|
||||
return path.Ext(fileName)
|
||||
}
|
||||
|
||||
func CheckNotExist(src string) bool {
|
||||
_, err := os.Stat(src)
|
||||
_, err := os.Stat(src)
|
||||
|
||||
return os.IsNotExist(err)
|
||||
return os.IsNotExist(err)
|
||||
}
|
||||
|
||||
func CheckPermission(src string) bool {
|
||||
_, err := os.Stat(src)
|
||||
_, err := os.Stat(src)
|
||||
|
||||
return os.IsPermission(err)
|
||||
return os.IsPermission(err)
|
||||
}
|
||||
|
||||
func IsNotExistMkDir(src string) error {
|
||||
if notExist := CheckNotExist(src); notExist == true {
|
||||
if err := MkDir(src); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if notExist := CheckNotExist(src); notExist == true {
|
||||
if err := MkDir(src); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func MkDir(src string) error {
|
||||
err := os.MkdirAll(src, os.ModePerm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err := os.MkdirAll(src, os.ModePerm)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func Open(name string, flag int, perm os.FileMode) (*os.File, error) {
|
||||
f, err := os.OpenFile(name, flag, perm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f, err := os.OpenFile(name, flag, perm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return f, nil
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func MustOpen(fileName, dir string) (*os.File, error) {
|
||||
perm := CheckPermission(dir)
|
||||
if perm == true {
|
||||
return nil, fmt.Errorf("permission denied dir: %s", dir)
|
||||
}
|
||||
perm := CheckPermission(dir)
|
||||
if perm == true {
|
||||
return nil, fmt.Errorf("permission denied dir: %s", dir)
|
||||
}
|
||||
|
||||
err := IsNotExistMkDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error during make dir %s, err: %s", dir, err)
|
||||
}
|
||||
err := IsNotExistMkDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error during make dir %s, err: %s", dir, err)
|
||||
}
|
||||
|
||||
f, err := Open(dir + string(os.PathSeparator) + fileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fail to open file, err: %s", err)
|
||||
}
|
||||
f, err := Open(dir+string(os.PathSeparator)+fileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0644)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fail to open file, err: %s", err)
|
||||
}
|
||||
|
||||
return f, nil
|
||||
return f, nil
|
||||
}
|
||||
|
@@ -1,9 +1,9 @@
|
||||
package geohash
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base32"
|
||||
"encoding/binary"
|
||||
"bytes"
|
||||
"encoding/base32"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
var bits = []uint8{128, 64, 32, 16, 8, 4, 2, 1}
|
||||
@@ -13,95 +13,95 @@ const defaultBitSize = 64 // 32 bits for latitude, another 32 bits for longitude
|
||||
|
||||
// return: geohash, box
|
||||
func encode0(latitude, longitude float64, bitSize uint) ([]byte, [2][2]float64) {
|
||||
box := [2][2]float64{
|
||||
{-180, 180}, // lng
|
||||
{-90, 90}, // lat
|
||||
}
|
||||
pos := [2]float64{longitude, latitude}
|
||||
hash := &bytes.Buffer{}
|
||||
bit := 0
|
||||
var precision uint = 0
|
||||
code := uint8(0)
|
||||
for precision < bitSize {
|
||||
for direction, val := range pos {
|
||||
mid := (box[direction][0] + box[direction][1]) / 2
|
||||
if val < mid {
|
||||
box[direction][1] = mid
|
||||
} else {
|
||||
box[direction][0] = mid
|
||||
code |= bits[bit]
|
||||
}
|
||||
bit++
|
||||
if bit == 8 {
|
||||
hash.WriteByte(code)
|
||||
bit = 0
|
||||
code = 0
|
||||
}
|
||||
precision++
|
||||
if precision == bitSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// precision%8 > 0
|
||||
if code > 0 {
|
||||
hash.WriteByte(code)
|
||||
}
|
||||
return hash.Bytes(), box
|
||||
box := [2][2]float64{
|
||||
{-180, 180}, // lng
|
||||
{-90, 90}, // lat
|
||||
}
|
||||
pos := [2]float64{longitude, latitude}
|
||||
hash := &bytes.Buffer{}
|
||||
bit := 0
|
||||
var precision uint = 0
|
||||
code := uint8(0)
|
||||
for precision < bitSize {
|
||||
for direction, val := range pos {
|
||||
mid := (box[direction][0] + box[direction][1]) / 2
|
||||
if val < mid {
|
||||
box[direction][1] = mid
|
||||
} else {
|
||||
box[direction][0] = mid
|
||||
code |= bits[bit]
|
||||
}
|
||||
bit++
|
||||
if bit == 8 {
|
||||
hash.WriteByte(code)
|
||||
bit = 0
|
||||
code = 0
|
||||
}
|
||||
precision++
|
||||
if precision == bitSize {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
// precision%8 > 0
|
||||
if code > 0 {
|
||||
hash.WriteByte(code)
|
||||
}
|
||||
return hash.Bytes(), box
|
||||
}
|
||||
|
||||
func Encode(latitude, longitude float64) uint64 {
|
||||
buf, _ := encode0(latitude, longitude, defaultBitSize)
|
||||
return binary.BigEndian.Uint64(buf)
|
||||
buf, _ := encode0(latitude, longitude, defaultBitSize)
|
||||
return binary.BigEndian.Uint64(buf)
|
||||
}
|
||||
|
||||
func decode0(hash []byte) [][]float64 {
|
||||
box := [][]float64{
|
||||
{-180, 180},
|
||||
{-90, 90},
|
||||
}
|
||||
direction := 0
|
||||
for i := 0; i < len(hash); i++ {
|
||||
code := hash[i]
|
||||
for j := 0; j < len(bits); j++ {
|
||||
mid := (box[direction][0] + box[direction][1]) / 2
|
||||
mask := bits[j]
|
||||
if mask&code > 0 {
|
||||
box[direction][0] = mid
|
||||
} else {
|
||||
box[direction][1] = mid
|
||||
}
|
||||
direction = (direction + 1) % 2
|
||||
}
|
||||
}
|
||||
return box
|
||||
box := [][]float64{
|
||||
{-180, 180},
|
||||
{-90, 90},
|
||||
}
|
||||
direction := 0
|
||||
for i := 0; i < len(hash); i++ {
|
||||
code := hash[i]
|
||||
for j := 0; j < len(bits); j++ {
|
||||
mid := (box[direction][0] + box[direction][1]) / 2
|
||||
mask := bits[j]
|
||||
if mask&code > 0 {
|
||||
box[direction][0] = mid
|
||||
} else {
|
||||
box[direction][1] = mid
|
||||
}
|
||||
direction = (direction + 1) % 2
|
||||
}
|
||||
}
|
||||
return box
|
||||
}
|
||||
|
||||
func Decode(code uint64) (float64, float64) {
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(buf, code)
|
||||
box := decode0(buf)
|
||||
lng := float64(box[0][0]+box[0][1]) / 2
|
||||
lat := float64(box[1][0]+box[1][1]) / 2
|
||||
return lat, lng
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(buf, code)
|
||||
box := decode0(buf)
|
||||
lng := float64(box[0][0]+box[0][1]) / 2
|
||||
lat := float64(box[1][0]+box[1][1]) / 2
|
||||
return lat, lng
|
||||
}
|
||||
|
||||
func ToString(buf []byte) string {
|
||||
return enc.EncodeToString(buf)
|
||||
return enc.EncodeToString(buf)
|
||||
}
|
||||
|
||||
func ToInt(buf []byte) uint64 {
|
||||
// padding
|
||||
if len(buf) < 8 {
|
||||
buf2 := make([]byte, 8)
|
||||
copy(buf2, buf)
|
||||
return binary.BigEndian.Uint64(buf2)
|
||||
}
|
||||
return binary.BigEndian.Uint64(buf)
|
||||
// padding
|
||||
if len(buf) < 8 {
|
||||
buf2 := make([]byte, 8)
|
||||
copy(buf2, buf)
|
||||
return binary.BigEndian.Uint64(buf2)
|
||||
}
|
||||
return binary.BigEndian.Uint64(buf)
|
||||
}
|
||||
|
||||
func FromInt(code uint64) []byte {
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(buf, code)
|
||||
return buf
|
||||
}
|
||||
buf := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(buf, code)
|
||||
return buf
|
||||
}
|
||||
|
@@ -1,39 +1,39 @@
|
||||
package geohash
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestToRange(t *testing.T) {
|
||||
neighbor := []byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00}
|
||||
range_ := ToRange(neighbor, 36)
|
||||
expectedLower := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00})
|
||||
expectedUpper := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xF0, 0x00, 0x00, 0x00})
|
||||
if expectedLower != range_[0] {
|
||||
t.Error("incorrect lower")
|
||||
}
|
||||
if expectedUpper != range_[1] {
|
||||
t.Error("incorrect upper")
|
||||
}
|
||||
neighbor := []byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00}
|
||||
range_ := ToRange(neighbor, 36)
|
||||
expectedLower := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00})
|
||||
expectedUpper := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xF0, 0x00, 0x00, 0x00})
|
||||
if expectedLower != range_[0] {
|
||||
t.Error("incorrect lower")
|
||||
}
|
||||
if expectedUpper != range_[1] {
|
||||
t.Error("incorrect upper")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncode(t *testing.T) {
|
||||
lat0 := 48.669
|
||||
lng0 := -4.32913
|
||||
hash := Encode(lat0, lng0)
|
||||
str := ToString(FromInt(hash))
|
||||
if str != "gbsuv7zt7zntw" {
|
||||
t.Error("encode error")
|
||||
}
|
||||
lat, lng := Decode(hash)
|
||||
if math.Abs(lat-lat0) > 1e-6 || math.Abs(lng-lng0) > 1e-6 {
|
||||
t.Error("decode error")
|
||||
}
|
||||
lat0 := 48.669
|
||||
lng0 := -4.32913
|
||||
hash := Encode(lat0, lng0)
|
||||
str := ToString(FromInt(hash))
|
||||
if str != "gbsuv7zt7zntw" {
|
||||
t.Error("encode error")
|
||||
}
|
||||
lat, lng := Decode(hash)
|
||||
if math.Abs(lat-lat0) > 1e-6 || math.Abs(lng-lng0) > 1e-6 {
|
||||
t.Error("decode error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetNeighbours(t *testing.T) {
|
||||
ranges := GetNeighbours(90, 180, 630*1000)
|
||||
fmt.Printf("%#v", ranges)
|
||||
}
|
||||
ranges := GetNeighbours(90, 180, 630*1000)
|
||||
fmt.Printf("%#v", ranges)
|
||||
}
|
||||
|
@@ -3,134 +3,134 @@ package geohash
|
||||
import "math"
|
||||
|
||||
const (
|
||||
DR = math.Pi / 180.0
|
||||
EarthRadius = 6372797.560856
|
||||
MercatorMax = 20037726.37 // pi * EarthRadius
|
||||
MercatorMin = -20037726.37
|
||||
DR = math.Pi / 180.0
|
||||
EarthRadius = 6372797.560856
|
||||
MercatorMax = 20037726.37 // pi * EarthRadius
|
||||
MercatorMin = -20037726.37
|
||||
)
|
||||
|
||||
func degRad(ang float64) float64 {
|
||||
return ang * DR
|
||||
return ang * DR
|
||||
}
|
||||
|
||||
func radDeg(ang float64) float64 {
|
||||
return ang / DR
|
||||
return ang / DR
|
||||
}
|
||||
|
||||
func getBoundingBox(latitude float64, longitude float64, radiusMeters float64) (
|
||||
minLat, maxLat, minLng, maxLng float64) {
|
||||
minLng = longitude - radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude)))
|
||||
if minLng < -180 {
|
||||
minLng = -180
|
||||
}
|
||||
maxLng = longitude + radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude)))
|
||||
if maxLng > 180 {
|
||||
maxLng = 180
|
||||
}
|
||||
minLat = latitude - radDeg(radiusMeters/EarthRadius)
|
||||
if minLat < -90 {
|
||||
minLat = -90
|
||||
}
|
||||
maxLat = latitude + radDeg(radiusMeters/EarthRadius)
|
||||
if maxLat > 90 {
|
||||
maxLat = 90
|
||||
}
|
||||
return
|
||||
minLat, maxLat, minLng, maxLng float64) {
|
||||
minLng = longitude - radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude)))
|
||||
if minLng < -180 {
|
||||
minLng = -180
|
||||
}
|
||||
maxLng = longitude + radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude)))
|
||||
if maxLng > 180 {
|
||||
maxLng = 180
|
||||
}
|
||||
minLat = latitude - radDeg(radiusMeters/EarthRadius)
|
||||
if minLat < -90 {
|
||||
minLat = -90
|
||||
}
|
||||
maxLat = latitude + radDeg(radiusMeters/EarthRadius)
|
||||
if maxLat > 90 {
|
||||
maxLat = 90
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func estimatePrecisionByRadius(radiusMeters float64, latitude float64) uint {
|
||||
if radiusMeters == 0 {
|
||||
return defaultBitSize - 1
|
||||
}
|
||||
var precision uint = 1
|
||||
for radiusMeters < MercatorMax {
|
||||
radiusMeters *= 2
|
||||
precision++
|
||||
}
|
||||
/* Make sure range is included in most of the base cases. */
|
||||
precision -= 2
|
||||
if latitude > 66 || latitude < -66 {
|
||||
precision--
|
||||
if latitude > 80 || latitude < -80 {
|
||||
precision--
|
||||
}
|
||||
}
|
||||
if precision < 1 {
|
||||
precision = 1
|
||||
}
|
||||
if precision > 32 {
|
||||
precision = 32
|
||||
}
|
||||
return precision*2 - 1
|
||||
if radiusMeters == 0 {
|
||||
return defaultBitSize - 1
|
||||
}
|
||||
var precision uint = 1
|
||||
for radiusMeters < MercatorMax {
|
||||
radiusMeters *= 2
|
||||
precision++
|
||||
}
|
||||
/* Make sure range is included in most of the base cases. */
|
||||
precision -= 2
|
||||
if latitude > 66 || latitude < -66 {
|
||||
precision--
|
||||
if latitude > 80 || latitude < -80 {
|
||||
precision--
|
||||
}
|
||||
}
|
||||
if precision < 1 {
|
||||
precision = 1
|
||||
}
|
||||
if precision > 32 {
|
||||
precision = 32
|
||||
}
|
||||
return precision*2 - 1
|
||||
}
|
||||
|
||||
func Distance(latitude1, longitude1, latitude2, longitude2 float64) float64 {
|
||||
radLat1 := degRad(latitude1)
|
||||
radLat2 := degRad(latitude2)
|
||||
a := radLat1 - radLat2
|
||||
b := degRad(longitude1) - degRad(longitude2)
|
||||
return 2 * EarthRadius * math.Asin(math.Sqrt(math.Pow(math.Sin(a/2), 2) +
|
||||
math.Cos(radLat1)*math.Cos(radLat2)*math.Pow(math.Sin(b/2), 2)))
|
||||
radLat1 := degRad(latitude1)
|
||||
radLat2 := degRad(latitude2)
|
||||
a := radLat1 - radLat2
|
||||
b := degRad(longitude1) - degRad(longitude2)
|
||||
return 2 * EarthRadius * math.Asin(math.Sqrt(math.Pow(math.Sin(a/2), 2)+
|
||||
math.Cos(radLat1)*math.Cos(radLat2)*math.Pow(math.Sin(b/2), 2)))
|
||||
}
|
||||
|
||||
func ToRange(scope []byte, precision uint) [2]uint64 {
|
||||
lower := ToInt(scope)
|
||||
radius := uint64(1 << (64 - precision))
|
||||
upper := lower + radius
|
||||
return [2]uint64{lower, upper}
|
||||
lower := ToInt(scope)
|
||||
radius := uint64(1 << (64 - precision))
|
||||
upper := lower + radius
|
||||
return [2]uint64{lower, upper}
|
||||
}
|
||||
|
||||
func ensureValidLat(lat float64) float64 {
|
||||
if lat > 90 {
|
||||
return 90
|
||||
}
|
||||
if lat < -90 {
|
||||
return -90
|
||||
}
|
||||
return lat
|
||||
if lat > 90 {
|
||||
return 90
|
||||
}
|
||||
if lat < -90 {
|
||||
return -90
|
||||
}
|
||||
return lat
|
||||
}
|
||||
|
||||
func ensureValidLng(lng float64) float64 {
|
||||
if lng > 180 {
|
||||
return -360 + lng
|
||||
}
|
||||
if lng < -180 {
|
||||
return 360 + lng
|
||||
}
|
||||
return lng
|
||||
if lng > 180 {
|
||||
return -360 + lng
|
||||
}
|
||||
if lng < -180 {
|
||||
return 360 + lng
|
||||
}
|
||||
return lng
|
||||
}
|
||||
|
||||
func GetNeighbours(latitude, longitude, radiusMeters float64) [][2]uint64 {
|
||||
precision := estimatePrecisionByRadius(radiusMeters, latitude)
|
||||
precision := estimatePrecisionByRadius(radiusMeters, latitude)
|
||||
|
||||
center, box := encode0(latitude, longitude, precision)
|
||||
height := box[0][1] - box[0][0]
|
||||
width := box[1][1] - box[1][0]
|
||||
centerLng := (box[0][1] + box[0][0]) / 2
|
||||
centerLat := (box[1][1] + box[1][0]) / 2
|
||||
maxLat := ensureValidLat(centerLat + height)
|
||||
minLat := ensureValidLat(centerLat - height)
|
||||
maxLng := ensureValidLng(centerLng + width)
|
||||
minLng := ensureValidLng(centerLng - width)
|
||||
center, box := encode0(latitude, longitude, precision)
|
||||
height := box[0][1] - box[0][0]
|
||||
width := box[1][1] - box[1][0]
|
||||
centerLng := (box[0][1] + box[0][0]) / 2
|
||||
centerLat := (box[1][1] + box[1][0]) / 2
|
||||
maxLat := ensureValidLat(centerLat + height)
|
||||
minLat := ensureValidLat(centerLat - height)
|
||||
maxLng := ensureValidLng(centerLng + width)
|
||||
minLng := ensureValidLng(centerLng - width)
|
||||
|
||||
var result [10][2]uint64
|
||||
leftUpper, _ := encode0(maxLat, minLng, precision)
|
||||
result[1] = ToRange(leftUpper, precision)
|
||||
upper, _ := encode0(maxLat, centerLng, precision)
|
||||
result[2] = ToRange(upper, precision)
|
||||
rightUpper, _ := encode0(maxLat, maxLng, precision)
|
||||
result[3] = ToRange(rightUpper, precision)
|
||||
left, _ := encode0(centerLat, minLng, precision)
|
||||
result[4] = ToRange(left, precision)
|
||||
result[5] = ToRange(center, precision)
|
||||
right, _ := encode0(centerLat, maxLng, precision)
|
||||
result[6] = ToRange(right, precision)
|
||||
leftDown, _ := encode0(minLat, minLng, precision)
|
||||
result[7] = ToRange(leftDown, precision)
|
||||
down, _ := encode0(minLat, centerLng, precision)
|
||||
result[8] = ToRange(down, precision)
|
||||
rightDown, _ := encode0(minLat, maxLng, precision)
|
||||
result[9] = ToRange(rightDown, precision)
|
||||
var result [10][2]uint64
|
||||
leftUpper, _ := encode0(maxLat, minLng, precision)
|
||||
result[1] = ToRange(leftUpper, precision)
|
||||
upper, _ := encode0(maxLat, centerLng, precision)
|
||||
result[2] = ToRange(upper, precision)
|
||||
rightUpper, _ := encode0(maxLat, maxLng, precision)
|
||||
result[3] = ToRange(rightUpper, precision)
|
||||
left, _ := encode0(centerLat, minLng, precision)
|
||||
result[4] = ToRange(left, precision)
|
||||
result[5] = ToRange(center, precision)
|
||||
right, _ := encode0(centerLat, maxLng, precision)
|
||||
result[6] = ToRange(right, precision)
|
||||
leftDown, _ := encode0(minLat, minLng, precision)
|
||||
result[7] = ToRange(leftDown, precision)
|
||||
down, _ := encode0(minLat, centerLng, precision)
|
||||
result[8] = ToRange(down, precision)
|
||||
rightDown, _ := encode0(minLat, maxLng, precision)
|
||||
result[9] = ToRange(rightDown, precision)
|
||||
|
||||
return result[1:]
|
||||
return result[1:]
|
||||
}
|
||||
|
@@ -1,21 +1,21 @@
|
||||
package gob
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
)
|
||||
|
||||
func Marshal(obj interface{}) ([]byte, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
enc := gob.NewEncoder(buf)
|
||||
err := enc.Encode(obj)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
buf := new(bytes.Buffer)
|
||||
enc := gob.NewEncoder(buf)
|
||||
err := enc.Encode(obj)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func UnMarshal(data []byte, obj interface{}) error {
|
||||
dec := gob.NewDecoder(bytes.NewBuffer(data))
|
||||
return dec.Decode(obj)
|
||||
dec := gob.NewDecoder(bytes.NewBuffer(data))
|
||||
return dec.Decode(obj)
|
||||
}
|
||||
|
@@ -4,16 +4,14 @@ import "sync/atomic"
|
||||
|
||||
type AtomicBool uint32
|
||||
|
||||
func (b *AtomicBool)Get()bool {
|
||||
return atomic.LoadUint32((*uint32)(b)) != 0
|
||||
func (b *AtomicBool) Get() bool {
|
||||
return atomic.LoadUint32((*uint32)(b)) != 0
|
||||
}
|
||||
|
||||
func (b *AtomicBool)Set(v bool) {
|
||||
if v {
|
||||
atomic.StoreUint32((*uint32)(b), 1)
|
||||
} else {
|
||||
atomic.StoreUint32((*uint32)(b), 0)
|
||||
}
|
||||
func (b *AtomicBool) Set(v bool) {
|
||||
if v {
|
||||
atomic.StoreUint32((*uint32)(b), 1)
|
||||
} else {
|
||||
atomic.StoreUint32((*uint32)(b), 0)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@@ -1,38 +1,38 @@
|
||||
package wait
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Wait struct {
|
||||
wg sync.WaitGroup
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func (w *Wait)Add(delta int) {
|
||||
w.wg.Add(delta)
|
||||
func (w *Wait) Add(delta int) {
|
||||
w.wg.Add(delta)
|
||||
}
|
||||
|
||||
func (w *Wait)Done() {
|
||||
w.wg.Done()
|
||||
func (w *Wait) Done() {
|
||||
w.wg.Done()
|
||||
}
|
||||
|
||||
func (w *Wait)Wait() {
|
||||
w.wg.Wait()
|
||||
func (w *Wait) Wait() {
|
||||
w.wg.Wait()
|
||||
}
|
||||
|
||||
// return isTimeout
|
||||
func (w *Wait)WaitWithTimeout(timeout time.Duration)bool {
|
||||
c := make(chan bool)
|
||||
go func() {
|
||||
defer close(c)
|
||||
w.wg.Wait()
|
||||
c <- true
|
||||
}()
|
||||
select {
|
||||
case <-c:
|
||||
return false // completed normally
|
||||
case <-time.After(timeout):
|
||||
return true // timed out
|
||||
}
|
||||
}
|
||||
func (w *Wait) WaitWithTimeout(timeout time.Duration) bool {
|
||||
c := make(chan bool)
|
||||
go func() {
|
||||
defer close(c)
|
||||
w.wg.Wait()
|
||||
c <- true
|
||||
}()
|
||||
select {
|
||||
case <-c:
|
||||
return false // completed normally
|
||||
case <-time.After(timeout):
|
||||
return true // timed out
|
||||
}
|
||||
}
|
||||
|
@@ -1,95 +1,95 @@
|
||||
package wildcard
|
||||
|
||||
const (
|
||||
normal = iota
|
||||
all // *
|
||||
any // ?
|
||||
set_ // []
|
||||
normal = iota
|
||||
all // *
|
||||
any // ?
|
||||
set_ // []
|
||||
)
|
||||
|
||||
type item struct {
|
||||
character byte
|
||||
set map[byte]bool
|
||||
typeCode int
|
||||
character byte
|
||||
set map[byte]bool
|
||||
typeCode int
|
||||
}
|
||||
|
||||
func (i *item) contains(c byte) bool {
|
||||
_, ok := i.set[c]
|
||||
return ok
|
||||
_, ok := i.set[c]
|
||||
return ok
|
||||
}
|
||||
|
||||
type Pattern struct {
|
||||
items []*item
|
||||
items []*item
|
||||
}
|
||||
|
||||
func CompilePattern(src string) *Pattern {
|
||||
items := make([]*item, 0)
|
||||
escape := false
|
||||
inSet := false
|
||||
var set map[byte]bool
|
||||
for _, v := range src {
|
||||
c := byte(v)
|
||||
if escape {
|
||||
items = append(items, &item{typeCode: normal, character: c})
|
||||
escape = false
|
||||
} else if c == '*' {
|
||||
items = append(items, &item{typeCode: all})
|
||||
} else if c == '?' {
|
||||
items = append(items, &item{typeCode: any})
|
||||
} else if c == '\\' {
|
||||
escape = true
|
||||
} else if c == '[' {
|
||||
if !inSet {
|
||||
inSet = true
|
||||
set = make(map[byte]bool)
|
||||
} else {
|
||||
set[c] = true
|
||||
}
|
||||
} else if c == ']' {
|
||||
if inSet {
|
||||
inSet = false
|
||||
items = append(items, &item{typeCode: set_, set: set})
|
||||
} else {
|
||||
items = append(items, &item{typeCode: normal, character: c})
|
||||
}
|
||||
} else {
|
||||
if inSet {
|
||||
set[c] = true
|
||||
} else {
|
||||
items = append(items, &item{typeCode: normal, character: c})
|
||||
}
|
||||
}
|
||||
}
|
||||
return &Pattern{
|
||||
items: items,
|
||||
}
|
||||
items := make([]*item, 0)
|
||||
escape := false
|
||||
inSet := false
|
||||
var set map[byte]bool
|
||||
for _, v := range src {
|
||||
c := byte(v)
|
||||
if escape {
|
||||
items = append(items, &item{typeCode: normal, character: c})
|
||||
escape = false
|
||||
} else if c == '*' {
|
||||
items = append(items, &item{typeCode: all})
|
||||
} else if c == '?' {
|
||||
items = append(items, &item{typeCode: any})
|
||||
} else if c == '\\' {
|
||||
escape = true
|
||||
} else if c == '[' {
|
||||
if !inSet {
|
||||
inSet = true
|
||||
set = make(map[byte]bool)
|
||||
} else {
|
||||
set[c] = true
|
||||
}
|
||||
} else if c == ']' {
|
||||
if inSet {
|
||||
inSet = false
|
||||
items = append(items, &item{typeCode: set_, set: set})
|
||||
} else {
|
||||
items = append(items, &item{typeCode: normal, character: c})
|
||||
}
|
||||
} else {
|
||||
if inSet {
|
||||
set[c] = true
|
||||
} else {
|
||||
items = append(items, &item{typeCode: normal, character: c})
|
||||
}
|
||||
}
|
||||
}
|
||||
return &Pattern{
|
||||
items: items,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pattern) IsMatch(s string) bool {
|
||||
if len(p.items) == 0 {
|
||||
return len(s) == 0
|
||||
}
|
||||
m := len(s)
|
||||
n := len(p.items)
|
||||
table := make([][]bool, m+1)
|
||||
for i := 0; i < m+1; i++ {
|
||||
table[i] = make([]bool, n+1)
|
||||
}
|
||||
table[0][0] = true
|
||||
for j := 1; j < n+1; j++ {
|
||||
table[0][j] = table[0][j-1] && p.items[j-1].typeCode == all
|
||||
}
|
||||
for i := 1; i < m+1; i++ {
|
||||
for j := 1; j < n+1; j++ {
|
||||
if p.items[j-1].typeCode == all {
|
||||
table[i][j] = table[i-1][j] || table[i][j-1]
|
||||
} else {
|
||||
table[i][j] = table[i-1][j-1] &&
|
||||
(p.items[j-1].typeCode == any ||
|
||||
(p.items[j-1].typeCode == normal && uint8(s[i-1]) == p.items[j-1].character) ||
|
||||
(p.items[j-1].typeCode == set_ && p.items[j-1].contains(s[i-1])))
|
||||
}
|
||||
}
|
||||
}
|
||||
return table[m][n]
|
||||
if len(p.items) == 0 {
|
||||
return len(s) == 0
|
||||
}
|
||||
m := len(s)
|
||||
n := len(p.items)
|
||||
table := make([][]bool, m+1)
|
||||
for i := 0; i < m+1; i++ {
|
||||
table[i] = make([]bool, n+1)
|
||||
}
|
||||
table[0][0] = true
|
||||
for j := 1; j < n+1; j++ {
|
||||
table[0][j] = table[0][j-1] && p.items[j-1].typeCode == all
|
||||
}
|
||||
for i := 1; i < m+1; i++ {
|
||||
for j := 1; j < n+1; j++ {
|
||||
if p.items[j-1].typeCode == all {
|
||||
table[i][j] = table[i-1][j] || table[i][j-1]
|
||||
} else {
|
||||
table[i][j] = table[i-1][j-1] &&
|
||||
(p.items[j-1].typeCode == any ||
|
||||
(p.items[j-1].typeCode == normal && uint8(s[i-1]) == p.items[j-1].character) ||
|
||||
(p.items[j-1].typeCode == set_ && p.items[j-1].contains(s[i-1])))
|
||||
}
|
||||
}
|
||||
}
|
||||
return table[m][n]
|
||||
}
|
||||
|
@@ -3,73 +3,73 @@ package wildcard
|
||||
import "testing"
|
||||
|
||||
func TestWildCard(t *testing.T) {
|
||||
p := CompilePattern("a")
|
||||
if !p.IsMatch("a") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if p.IsMatch("b") {
|
||||
t.Error("expect false actually true")
|
||||
}
|
||||
p := CompilePattern("a")
|
||||
if !p.IsMatch("a") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if p.IsMatch("b") {
|
||||
t.Error("expect false actually true")
|
||||
}
|
||||
|
||||
// test '?'
|
||||
p = CompilePattern("a?")
|
||||
if !p.IsMatch("ab") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if p.IsMatch("a") {
|
||||
t.Error("expect false actually true")
|
||||
}
|
||||
if p.IsMatch("abb") {
|
||||
t.Error("expect false actually true")
|
||||
}
|
||||
if p.IsMatch("bb") {
|
||||
t.Error("expect false actually true")
|
||||
}
|
||||
// test '?'
|
||||
p = CompilePattern("a?")
|
||||
if !p.IsMatch("ab") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if p.IsMatch("a") {
|
||||
t.Error("expect false actually true")
|
||||
}
|
||||
if p.IsMatch("abb") {
|
||||
t.Error("expect false actually true")
|
||||
}
|
||||
if p.IsMatch("bb") {
|
||||
t.Error("expect false actually true")
|
||||
}
|
||||
|
||||
// test *
|
||||
p = CompilePattern("a*")
|
||||
if !p.IsMatch("ab") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if !p.IsMatch("a") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if !p.IsMatch("abb") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if p.IsMatch("bb") {
|
||||
t.Error("expect false actually true")
|
||||
}
|
||||
// test *
|
||||
p = CompilePattern("a*")
|
||||
if !p.IsMatch("ab") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if !p.IsMatch("a") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if !p.IsMatch("abb") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if p.IsMatch("bb") {
|
||||
t.Error("expect false actually true")
|
||||
}
|
||||
|
||||
// test []
|
||||
p = CompilePattern("a[ab[]")
|
||||
if !p.IsMatch("ab") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if !p.IsMatch("aa") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if !p.IsMatch("a[") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if p.IsMatch("abb") {
|
||||
t.Error("expect false actually true")
|
||||
}
|
||||
if p.IsMatch("bb") {
|
||||
t.Error("expect false actually true")
|
||||
}
|
||||
// test []
|
||||
p = CompilePattern("a[ab[]")
|
||||
if !p.IsMatch("ab") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if !p.IsMatch("aa") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if !p.IsMatch("a[") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if p.IsMatch("abb") {
|
||||
t.Error("expect false actually true")
|
||||
}
|
||||
if p.IsMatch("bb") {
|
||||
t.Error("expect false actually true")
|
||||
}
|
||||
|
||||
// test escape
|
||||
p = CompilePattern("\\\\") // pattern: \\
|
||||
if !p.IsMatch("\\") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
// test escape
|
||||
p = CompilePattern("\\\\") // pattern: \\
|
||||
if !p.IsMatch("\\") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
|
||||
p = CompilePattern("\\*")
|
||||
if !p.IsMatch("*") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if p.IsMatch("a") {
|
||||
t.Error("expect false actually true")
|
||||
}
|
||||
p = CompilePattern("\\*")
|
||||
if !p.IsMatch("*") {
|
||||
t.Error("expect true actually false")
|
||||
}
|
||||
if p.IsMatch("a") {
|
||||
t.Error("expect false actually true")
|
||||
}
|
||||
}
|
||||
|
@@ -1,20 +1,20 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"github.com/HDT3213/godis/src/datastruct/dict"
|
||||
"github.com/HDT3213/godis/src/datastruct/lock"
|
||||
"github.com/HDT3213/godis/src/datastruct/dict"
|
||||
"github.com/HDT3213/godis/src/datastruct/lock"
|
||||
)
|
||||
|
||||
type Hub struct {
|
||||
// channel -> list(*Client)
|
||||
subs dict.Dict
|
||||
// lock channel
|
||||
subsLocker *lock.Locks
|
||||
// channel -> list(*Client)
|
||||
subs dict.Dict
|
||||
// lock channel
|
||||
subsLocker *lock.Locks
|
||||
}
|
||||
|
||||
func MakeHub() *Hub {
|
||||
return &Hub{
|
||||
subs: dict.MakeConcurrent(4),
|
||||
subsLocker: lock.Make(16),
|
||||
}
|
||||
return &Hub{
|
||||
subs: dict.MakeConcurrent(4),
|
||||
subsLocker: lock.Make(16),
|
||||
}
|
||||
}
|
||||
|
@@ -1,23 +1,23 @@
|
||||
package pubsub
|
||||
|
||||
import (
|
||||
"github.com/HDT3213/godis/src/datastruct/list"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"strconv"
|
||||
"github.com/HDT3213/godis/src/datastruct/list"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var (
|
||||
_subscribe = "subscribe"
|
||||
_unsubscribe = "unsubscribe"
|
||||
messageBytes = []byte("message")
|
||||
unSubscribeNothing = []byte("*3\r\n$11\r\nunsubscribe\r\n$-1\n:0\r\n")
|
||||
_subscribe = "subscribe"
|
||||
_unsubscribe = "unsubscribe"
|
||||
messageBytes = []byte("message")
|
||||
unSubscribeNothing = []byte("*3\r\n$11\r\nunsubscribe\r\n$-1\n:0\r\n")
|
||||
)
|
||||
|
||||
func makeMsg(t string, channel string, code int64) []byte {
|
||||
return []byte("*3\r\n$" + strconv.FormatInt(int64(len(t)), 10) + reply.CRLF + t + reply.CRLF +
|
||||
"$" + strconv.FormatInt(int64(len(channel)), 10) + reply.CRLF + channel + reply.CRLF +
|
||||
":" + strconv.FormatInt(code, 10) + reply.CRLF)
|
||||
return []byte("*3\r\n$" + strconv.FormatInt(int64(len(t)), 10) + reply.CRLF + t + reply.CRLF +
|
||||
"$" + strconv.FormatInt(int64(len(channel)), 10) + reply.CRLF + channel + reply.CRLF +
|
||||
":" + strconv.FormatInt(code, 10) + reply.CRLF)
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -25,22 +25,22 @@ func makeMsg(t string, channel string, code int64) []byte {
|
||||
* return: is new subscribed
|
||||
*/
|
||||
func subscribe0(hub *Hub, channel string, client redis.Connection) bool {
|
||||
client.SubsChannel(channel)
|
||||
client.SubsChannel(channel)
|
||||
|
||||
// add into hub.subs
|
||||
raw, ok := hub.subs.Get(channel)
|
||||
var subscribers *list.LinkedList
|
||||
if ok {
|
||||
subscribers, _ = raw.(*list.LinkedList)
|
||||
} else {
|
||||
subscribers = list.Make()
|
||||
hub.subs.Put(channel, subscribers)
|
||||
}
|
||||
if subscribers.Contains(client) {
|
||||
return false
|
||||
}
|
||||
subscribers.Add(client)
|
||||
return true
|
||||
// add into hub.subs
|
||||
raw, ok := hub.subs.Get(channel)
|
||||
var subscribers *list.LinkedList
|
||||
if ok {
|
||||
subscribers, _ = raw.(*list.LinkedList)
|
||||
} else {
|
||||
subscribers = list.Make()
|
||||
hub.subs.Put(channel, subscribers)
|
||||
}
|
||||
if subscribers.Contains(client) {
|
||||
return false
|
||||
}
|
||||
subscribers.Add(client)
|
||||
return true
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -48,102 +48,102 @@ func subscribe0(hub *Hub, channel string, client redis.Connection) bool {
|
||||
* return: is actually un-subscribe
|
||||
*/
|
||||
func unsubscribe0(hub *Hub, channel string, client redis.Connection) bool {
|
||||
client.UnSubsChannel(channel)
|
||||
client.UnSubsChannel(channel)
|
||||
|
||||
// remove from hub.subs
|
||||
raw, ok := hub.subs.Get(channel)
|
||||
if ok {
|
||||
subscribers, _ := raw.(*list.LinkedList)
|
||||
subscribers.RemoveAllByVal(client)
|
||||
// remove from hub.subs
|
||||
raw, ok := hub.subs.Get(channel)
|
||||
if ok {
|
||||
subscribers, _ := raw.(*list.LinkedList)
|
||||
subscribers.RemoveAllByVal(client)
|
||||
|
||||
if subscribers.Len() == 0 {
|
||||
// clean
|
||||
hub.subs.Remove(channel)
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
if subscribers.Len() == 0 {
|
||||
// clean
|
||||
hub.subs.Remove(channel)
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func Subscribe(hub *Hub, c redis.Connection, args [][]byte) redis.Reply {
|
||||
channels := make([]string, len(args))
|
||||
for i, b := range args {
|
||||
channels[i] = string(b)
|
||||
}
|
||||
channels := make([]string, len(args))
|
||||
for i, b := range args {
|
||||
channels[i] = string(b)
|
||||
}
|
||||
|
||||
hub.subsLocker.Locks(channels...)
|
||||
defer hub.subsLocker.UnLocks(channels...)
|
||||
hub.subsLocker.Locks(channels...)
|
||||
defer hub.subsLocker.UnLocks(channels...)
|
||||
|
||||
for _, channel := range channels {
|
||||
if subscribe0(hub, channel, c) {
|
||||
_ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount())))
|
||||
}
|
||||
}
|
||||
return &reply.NoReply{}
|
||||
for _, channel := range channels {
|
||||
if subscribe0(hub, channel, c) {
|
||||
_ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount())))
|
||||
}
|
||||
}
|
||||
return &reply.NoReply{}
|
||||
}
|
||||
|
||||
func UnsubscribeAll(hub *Hub, c redis.Connection) {
|
||||
channels := c.GetChannels()
|
||||
channels := c.GetChannels()
|
||||
|
||||
hub.subsLocker.Locks(channels...)
|
||||
defer hub.subsLocker.UnLocks(channels...)
|
||||
hub.subsLocker.Locks(channels...)
|
||||
defer hub.subsLocker.UnLocks(channels...)
|
||||
|
||||
for _, channel := range channels {
|
||||
unsubscribe0(hub, channel, c)
|
||||
}
|
||||
for _, channel := range channels {
|
||||
unsubscribe0(hub, channel, c)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func UnSubscribe(db *Hub, c redis.Connection, args [][]byte) redis.Reply {
|
||||
var channels []string
|
||||
if len(args) > 0 {
|
||||
channels = make([]string, len(args))
|
||||
for i, b := range args {
|
||||
channels[i] = string(b)
|
||||
}
|
||||
} else {
|
||||
channels = c.GetChannels()
|
||||
}
|
||||
var channels []string
|
||||
if len(args) > 0 {
|
||||
channels = make([]string, len(args))
|
||||
for i, b := range args {
|
||||
channels[i] = string(b)
|
||||
}
|
||||
} else {
|
||||
channels = c.GetChannels()
|
||||
}
|
||||
|
||||
db.subsLocker.Locks(channels...)
|
||||
defer db.subsLocker.UnLocks(channels...)
|
||||
db.subsLocker.Locks(channels...)
|
||||
defer db.subsLocker.UnLocks(channels...)
|
||||
|
||||
if len(channels) == 0 {
|
||||
_ = c.Write(unSubscribeNothing)
|
||||
return &reply.NoReply{}
|
||||
}
|
||||
if len(channels) == 0 {
|
||||
_ = c.Write(unSubscribeNothing)
|
||||
return &reply.NoReply{}
|
||||
}
|
||||
|
||||
for _, channel := range channels {
|
||||
if unsubscribe0(db, channel, c) {
|
||||
_ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount())))
|
||||
}
|
||||
}
|
||||
return &reply.NoReply{}
|
||||
for _, channel := range channels {
|
||||
if unsubscribe0(db, channel, c) {
|
||||
_ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount())))
|
||||
}
|
||||
}
|
||||
return &reply.NoReply{}
|
||||
}
|
||||
|
||||
func Publish(hub *Hub, args [][]byte) redis.Reply {
|
||||
if len(args) != 2 {
|
||||
return &reply.ArgNumErrReply{Cmd: "publish"}
|
||||
}
|
||||
channel := string(args[0])
|
||||
message := args[1]
|
||||
if len(args) != 2 {
|
||||
return &reply.ArgNumErrReply{Cmd: "publish"}
|
||||
}
|
||||
channel := string(args[0])
|
||||
message := args[1]
|
||||
|
||||
hub.subsLocker.Lock(channel)
|
||||
defer hub.subsLocker.UnLock(channel)
|
||||
hub.subsLocker.Lock(channel)
|
||||
defer hub.subsLocker.UnLock(channel)
|
||||
|
||||
raw, ok := hub.subs.Get(channel)
|
||||
if !ok {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
subscribers, _ := raw.(*list.LinkedList)
|
||||
subscribers.ForEach(func(i int, c interface{}) bool {
|
||||
client, _ := c.(redis.Connection)
|
||||
replyArgs := make([][]byte, 3)
|
||||
replyArgs[0] = messageBytes
|
||||
replyArgs[1] = []byte(channel)
|
||||
replyArgs[2] = message
|
||||
_ = client.Write(reply.MakeMultiBulkReply(replyArgs).ToBytes())
|
||||
return true
|
||||
})
|
||||
return reply.MakeIntReply(int64(subscribers.Len()))
|
||||
raw, ok := hub.subs.Get(channel)
|
||||
if !ok {
|
||||
return reply.MakeIntReply(0)
|
||||
}
|
||||
subscribers, _ := raw.(*list.LinkedList)
|
||||
subscribers.ForEach(func(i int, c interface{}) bool {
|
||||
client, _ := c.(redis.Connection)
|
||||
replyArgs := make([][]byte, 3)
|
||||
replyArgs[0] = messageBytes
|
||||
replyArgs[1] = []byte(channel)
|
||||
replyArgs[2] = message
|
||||
_ = client.Write(reply.MakeMultiBulkReply(replyArgs).ToBytes())
|
||||
return true
|
||||
})
|
||||
return reply.MakeIntReply(int64(subscribers.Len()))
|
||||
}
|
||||
|
@@ -1,322 +1,321 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/lib/logger"
|
||||
"github.com/HDT3213/godis/src/lib/sync/wait"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"github.com/HDT3213/godis/src/lib/logger"
|
||||
"github.com/HDT3213/godis/src/lib/sync/wait"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
conn net.Conn
|
||||
sendingReqs chan *Request // waiting sending
|
||||
waitingReqs chan *Request // waiting response
|
||||
ticker *time.Ticker
|
||||
addr string
|
||||
conn net.Conn
|
||||
sendingReqs chan *Request // waiting sending
|
||||
waitingReqs chan *Request // waiting response
|
||||
ticker *time.Ticker
|
||||
addr string
|
||||
|
||||
ctx context.Context
|
||||
cancelFunc context.CancelFunc
|
||||
writing *sync.WaitGroup
|
||||
ctx context.Context
|
||||
cancelFunc context.CancelFunc
|
||||
writing *sync.WaitGroup
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
id uint64
|
||||
args [][]byte
|
||||
reply redis.Reply
|
||||
heartbeat bool
|
||||
waiting *wait.Wait
|
||||
err error
|
||||
id uint64
|
||||
args [][]byte
|
||||
reply redis.Reply
|
||||
heartbeat bool
|
||||
waiting *wait.Wait
|
||||
err error
|
||||
}
|
||||
|
||||
const (
|
||||
chanSize = 256
|
||||
maxWait = 3 * time.Second
|
||||
chanSize = 256
|
||||
maxWait = 3 * time.Second
|
||||
)
|
||||
|
||||
func MakeClient(addr string) (*Client, error) {
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Client{
|
||||
addr: addr,
|
||||
conn: conn,
|
||||
sendingReqs: make(chan *Request, chanSize),
|
||||
waitingReqs: make(chan *Request, chanSize),
|
||||
ctx: ctx,
|
||||
cancelFunc: cancel,
|
||||
writing: &sync.WaitGroup{},
|
||||
}, nil
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Client{
|
||||
addr: addr,
|
||||
conn: conn,
|
||||
sendingReqs: make(chan *Request, chanSize),
|
||||
waitingReqs: make(chan *Request, chanSize),
|
||||
ctx: ctx,
|
||||
cancelFunc: cancel,
|
||||
writing: &sync.WaitGroup{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (client *Client) Start() {
|
||||
client.ticker = time.NewTicker(10 * time.Second)
|
||||
go client.handleWrite()
|
||||
go func() {
|
||||
err := client.handleRead()
|
||||
logger.Warn(err)
|
||||
}()
|
||||
go client.heartbeat()
|
||||
client.ticker = time.NewTicker(10 * time.Second)
|
||||
go client.handleWrite()
|
||||
go func() {
|
||||
err := client.handleRead()
|
||||
logger.Warn(err)
|
||||
}()
|
||||
go client.heartbeat()
|
||||
}
|
||||
|
||||
func (client *Client) Close() {
|
||||
// stop new request
|
||||
close(client.sendingReqs)
|
||||
// stop new request
|
||||
close(client.sendingReqs)
|
||||
|
||||
// wait stop process
|
||||
client.writing.Wait()
|
||||
// wait stop process
|
||||
client.writing.Wait()
|
||||
|
||||
// clean
|
||||
client.cancelFunc()
|
||||
_ = client.conn.Close()
|
||||
close(client.waitingReqs)
|
||||
// clean
|
||||
client.cancelFunc()
|
||||
_ = client.conn.Close()
|
||||
close(client.waitingReqs)
|
||||
}
|
||||
|
||||
func (client *Client) handleConnectionError(err error) error {
|
||||
err1 := client.conn.Close()
|
||||
if err1 != nil {
|
||||
if opErr, ok := err1.(*net.OpError); ok {
|
||||
if opErr.Err.Error() != "use of closed network connection" {
|
||||
return err1
|
||||
}
|
||||
} else {
|
||||
return err1
|
||||
}
|
||||
}
|
||||
conn, err1 := net.Dial("tcp", client.addr)
|
||||
if err1 != nil {
|
||||
logger.Error(err1)
|
||||
return err1
|
||||
}
|
||||
client.conn = conn
|
||||
go func() {
|
||||
_ = client.handleRead()
|
||||
}()
|
||||
return nil
|
||||
err1 := client.conn.Close()
|
||||
if err1 != nil {
|
||||
if opErr, ok := err1.(*net.OpError); ok {
|
||||
if opErr.Err.Error() != "use of closed network connection" {
|
||||
return err1
|
||||
}
|
||||
} else {
|
||||
return err1
|
||||
}
|
||||
}
|
||||
conn, err1 := net.Dial("tcp", client.addr)
|
||||
if err1 != nil {
|
||||
logger.Error(err1)
|
||||
return err1
|
||||
}
|
||||
client.conn = conn
|
||||
go func() {
|
||||
_ = client.handleRead()
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (client *Client) heartbeat() {
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-client.ticker.C:
|
||||
client.sendingReqs <- &Request{
|
||||
args: [][]byte{[]byte("PING")},
|
||||
heartbeat: true,
|
||||
}
|
||||
case <-client.ctx.Done():
|
||||
break loop
|
||||
}
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-client.ticker.C:
|
||||
client.sendingReqs <- &Request{
|
||||
args: [][]byte{[]byte("PING")},
|
||||
heartbeat: true,
|
||||
}
|
||||
case <-client.ctx.Done():
|
||||
break loop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (client *Client) handleWrite() {
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case req := <-client.sendingReqs:
|
||||
client.writing.Add(1)
|
||||
client.doRequest(req)
|
||||
case <-client.ctx.Done():
|
||||
break loop
|
||||
}
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case req := <-client.sendingReqs:
|
||||
client.writing.Add(1)
|
||||
client.doRequest(req)
|
||||
case <-client.ctx.Done():
|
||||
break loop
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// todo: wait with timeout
|
||||
func (client *Client) Send(args [][]byte) redis.Reply {
|
||||
request := &Request{
|
||||
args: args,
|
||||
heartbeat: false,
|
||||
waiting: &wait.Wait{},
|
||||
}
|
||||
request.waiting.Add(1)
|
||||
client.sendingReqs <- request
|
||||
timeout := request.waiting.WaitWithTimeout(maxWait)
|
||||
if timeout {
|
||||
return reply.MakeErrReply("server time out")
|
||||
}
|
||||
if request.err != nil {
|
||||
return reply.MakeErrReply("request failed")
|
||||
}
|
||||
return request.reply
|
||||
request := &Request{
|
||||
args: args,
|
||||
heartbeat: false,
|
||||
waiting: &wait.Wait{},
|
||||
}
|
||||
request.waiting.Add(1)
|
||||
client.sendingReqs <- request
|
||||
timeout := request.waiting.WaitWithTimeout(maxWait)
|
||||
if timeout {
|
||||
return reply.MakeErrReply("server time out")
|
||||
}
|
||||
if request.err != nil {
|
||||
return reply.MakeErrReply("request failed")
|
||||
}
|
||||
return request.reply
|
||||
}
|
||||
|
||||
func (client *Client) doRequest(req *Request) {
|
||||
bytes := reply.MakeMultiBulkReply(req.args).ToBytes()
|
||||
_, err := client.conn.Write(bytes)
|
||||
i := 0
|
||||
for err != nil && i < 3 {
|
||||
err = client.handleConnectionError(err)
|
||||
if err == nil {
|
||||
_, err = client.conn.Write(bytes)
|
||||
}
|
||||
i++
|
||||
}
|
||||
if err == nil {
|
||||
client.waitingReqs <- req
|
||||
} else {
|
||||
req.err = err
|
||||
req.waiting.Done()
|
||||
client.writing.Done()
|
||||
}
|
||||
bytes := reply.MakeMultiBulkReply(req.args).ToBytes()
|
||||
_, err := client.conn.Write(bytes)
|
||||
i := 0
|
||||
for err != nil && i < 3 {
|
||||
err = client.handleConnectionError(err)
|
||||
if err == nil {
|
||||
_, err = client.conn.Write(bytes)
|
||||
}
|
||||
i++
|
||||
}
|
||||
if err == nil {
|
||||
client.waitingReqs <- req
|
||||
} else {
|
||||
req.err = err
|
||||
req.waiting.Done()
|
||||
client.writing.Done()
|
||||
}
|
||||
}
|
||||
|
||||
func (client *Client) finishRequest(reply redis.Reply) {
|
||||
request := <-client.waitingReqs
|
||||
request.reply = reply
|
||||
if request.waiting != nil {
|
||||
request.waiting.Done()
|
||||
}
|
||||
client.writing.Done()
|
||||
request := <-client.waitingReqs
|
||||
request.reply = reply
|
||||
if request.waiting != nil {
|
||||
request.waiting.Done()
|
||||
}
|
||||
client.writing.Done()
|
||||
}
|
||||
|
||||
func (client *Client) handleRead() error {
|
||||
reader := bufio.NewReader(client.conn)
|
||||
downloading := false
|
||||
expectedArgsCount := 0
|
||||
receivedCount := 0
|
||||
msgType := byte(0) // first char of msg
|
||||
var args [][]byte
|
||||
var fixedLen int64 = 0
|
||||
var err error
|
||||
var msg []byte
|
||||
for {
|
||||
// read line
|
||||
if fixedLen == 0 { // read normal line
|
||||
msg, err = reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
logger.Info("connection close")
|
||||
} else {
|
||||
logger.Warn(err)
|
||||
}
|
||||
reader := bufio.NewReader(client.conn)
|
||||
downloading := false
|
||||
expectedArgsCount := 0
|
||||
receivedCount := 0
|
||||
msgType := byte(0) // first char of msg
|
||||
var args [][]byte
|
||||
var fixedLen int64 = 0
|
||||
var err error
|
||||
var msg []byte
|
||||
for {
|
||||
// read line
|
||||
if fixedLen == 0 { // read normal line
|
||||
msg, err = reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
logger.Info("connection close")
|
||||
} else {
|
||||
logger.Warn(err)
|
||||
}
|
||||
|
||||
return errors.New("connection closed")
|
||||
}
|
||||
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
|
||||
return errors.New("protocol error")
|
||||
}
|
||||
} else { // read bulk line (binary safe)
|
||||
msg = make([]byte, fixedLen+2)
|
||||
_, err = io.ReadFull(reader, msg)
|
||||
if err != nil {
|
||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
return errors.New("connection closed")
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(msg) == 0 ||
|
||||
msg[len(msg)-2] != '\r' ||
|
||||
msg[len(msg)-1] != '\n' {
|
||||
return errors.New("protocol error")
|
||||
}
|
||||
fixedLen = 0
|
||||
}
|
||||
return errors.New("connection closed")
|
||||
}
|
||||
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
|
||||
return errors.New("protocol error")
|
||||
}
|
||||
} else { // read bulk line (binary safe)
|
||||
msg = make([]byte, fixedLen+2)
|
||||
_, err = io.ReadFull(reader, msg)
|
||||
if err != nil {
|
||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
return errors.New("connection closed")
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(msg) == 0 ||
|
||||
msg[len(msg)-2] != '\r' ||
|
||||
msg[len(msg)-1] != '\n' {
|
||||
return errors.New("protocol error")
|
||||
}
|
||||
fixedLen = 0
|
||||
}
|
||||
|
||||
// parse line
|
||||
if !downloading {
|
||||
// receive new response
|
||||
if msg[0] == '*' { // multi bulk response
|
||||
// bulk multi msg
|
||||
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
|
||||
if err != nil {
|
||||
return errors.New("protocol error: " + err.Error())
|
||||
}
|
||||
if expectedLine == 0 {
|
||||
client.finishRequest(&reply.EmptyMultiBulkReply{})
|
||||
} else if expectedLine > 0 {
|
||||
msgType = msg[0]
|
||||
downloading = true
|
||||
expectedArgsCount = int(expectedLine)
|
||||
receivedCount = 0
|
||||
args = make([][]byte, expectedLine)
|
||||
} else {
|
||||
return errors.New("protocol error")
|
||||
}
|
||||
} else if msg[0] == '$' { // bulk response
|
||||
fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fixedLen == -1 { // null bulk
|
||||
client.finishRequest(&reply.NullBulkReply{})
|
||||
fixedLen = 0
|
||||
} else if fixedLen > 0 {
|
||||
msgType = msg[0]
|
||||
downloading = true
|
||||
expectedArgsCount = 1
|
||||
receivedCount = 0
|
||||
args = make([][]byte, 1)
|
||||
} else {
|
||||
return errors.New("protocol error")
|
||||
}
|
||||
} else { // single line response
|
||||
str := strings.TrimSuffix(string(msg), "\n")
|
||||
str = strings.TrimSuffix(str, "\r")
|
||||
var result redis.Reply
|
||||
switch msg[0] {
|
||||
case '+':
|
||||
result = reply.MakeStatusReply(str[1:])
|
||||
case '-':
|
||||
result = reply.MakeErrReply(str[1:])
|
||||
case ':':
|
||||
val, err := strconv.ParseInt(str[1:], 10, 64)
|
||||
if err != nil {
|
||||
return errors.New("protocol error")
|
||||
}
|
||||
result = reply.MakeIntReply(val)
|
||||
}
|
||||
client.finishRequest(result)
|
||||
}
|
||||
} else {
|
||||
// receive following part of a request
|
||||
line := msg[0 : len(msg)-2]
|
||||
if line[0] == '$' {
|
||||
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fixedLen <= 0 { // null bulk in multi bulks
|
||||
args[receivedCount] = []byte{}
|
||||
receivedCount++
|
||||
fixedLen = 0
|
||||
}
|
||||
} else {
|
||||
args[receivedCount] = line
|
||||
receivedCount++
|
||||
}
|
||||
// parse line
|
||||
if !downloading {
|
||||
// receive new response
|
||||
if msg[0] == '*' { // multi bulk response
|
||||
// bulk multi msg
|
||||
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
|
||||
if err != nil {
|
||||
return errors.New("protocol error: " + err.Error())
|
||||
}
|
||||
if expectedLine == 0 {
|
||||
client.finishRequest(&reply.EmptyMultiBulkReply{})
|
||||
} else if expectedLine > 0 {
|
||||
msgType = msg[0]
|
||||
downloading = true
|
||||
expectedArgsCount = int(expectedLine)
|
||||
receivedCount = 0
|
||||
args = make([][]byte, expectedLine)
|
||||
} else {
|
||||
return errors.New("protocol error")
|
||||
}
|
||||
} else if msg[0] == '$' { // bulk response
|
||||
fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fixedLen == -1 { // null bulk
|
||||
client.finishRequest(&reply.NullBulkReply{})
|
||||
fixedLen = 0
|
||||
} else if fixedLen > 0 {
|
||||
msgType = msg[0]
|
||||
downloading = true
|
||||
expectedArgsCount = 1
|
||||
receivedCount = 0
|
||||
args = make([][]byte, 1)
|
||||
} else {
|
||||
return errors.New("protocol error")
|
||||
}
|
||||
} else { // single line response
|
||||
str := strings.TrimSuffix(string(msg), "\n")
|
||||
str = strings.TrimSuffix(str, "\r")
|
||||
var result redis.Reply
|
||||
switch msg[0] {
|
||||
case '+':
|
||||
result = reply.MakeStatusReply(str[1:])
|
||||
case '-':
|
||||
result = reply.MakeErrReply(str[1:])
|
||||
case ':':
|
||||
val, err := strconv.ParseInt(str[1:], 10, 64)
|
||||
if err != nil {
|
||||
return errors.New("protocol error")
|
||||
}
|
||||
result = reply.MakeIntReply(val)
|
||||
}
|
||||
client.finishRequest(result)
|
||||
}
|
||||
} else {
|
||||
// receive following part of a request
|
||||
line := msg[0 : len(msg)-2]
|
||||
if line[0] == '$' {
|
||||
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fixedLen <= 0 { // null bulk in multi bulks
|
||||
args[receivedCount] = []byte{}
|
||||
receivedCount++
|
||||
fixedLen = 0
|
||||
}
|
||||
} else {
|
||||
args[receivedCount] = line
|
||||
receivedCount++
|
||||
}
|
||||
|
||||
// if sending finished
|
||||
if receivedCount == expectedArgsCount {
|
||||
downloading = false // finish downloading progress
|
||||
// if sending finished
|
||||
if receivedCount == expectedArgsCount {
|
||||
downloading = false // finish downloading progress
|
||||
|
||||
if msgType == '*' {
|
||||
reply := reply.MakeMultiBulkReply(args)
|
||||
client.finishRequest(reply)
|
||||
} else if msgType == '$' {
|
||||
reply := reply.MakeBulkReply(args[0])
|
||||
client.finishRequest(reply)
|
||||
}
|
||||
if msgType == '*' {
|
||||
reply := reply.MakeMultiBulkReply(args)
|
||||
client.finishRequest(reply)
|
||||
} else if msgType == '$' {
|
||||
reply := reply.MakeBulkReply(args[0])
|
||||
client.finishRequest(reply)
|
||||
}
|
||||
|
||||
|
||||
// finish reply
|
||||
expectedArgsCount = 0
|
||||
receivedCount = 0
|
||||
args = nil
|
||||
msgType = byte(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
// finish reply
|
||||
expectedArgsCount = 0
|
||||
receivedCount = 0
|
||||
args = nil
|
||||
msgType = byte(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1,104 +1,104 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"github.com/HDT3213/godis/src/lib/logger"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"testing"
|
||||
"github.com/HDT3213/godis/src/lib/logger"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClient(t *testing.T) {
|
||||
logger.Setup(&logger.Settings{
|
||||
Path: "logs",
|
||||
Name: "godis",
|
||||
Ext: ".log",
|
||||
TimeFormat: "2006-01-02",
|
||||
})
|
||||
client, err := MakeClient("localhost:6379")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
client.Start()
|
||||
logger.Setup(&logger.Settings{
|
||||
Path: "logs",
|
||||
Name: "godis",
|
||||
Ext: ".log",
|
||||
TimeFormat: "2006-01-02",
|
||||
})
|
||||
client, err := MakeClient("localhost:6379")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
client.Start()
|
||||
|
||||
result := client.Send([][]byte{
|
||||
[]byte("PING"),
|
||||
})
|
||||
if statusRet, ok := result.(*reply.StatusReply); ok {
|
||||
if statusRet.Status != "PONG" {
|
||||
t.Error("`ping` failed, result: " + statusRet.Status)
|
||||
}
|
||||
}
|
||||
result := client.Send([][]byte{
|
||||
[]byte("PING"),
|
||||
})
|
||||
if statusRet, ok := result.(*reply.StatusReply); ok {
|
||||
if statusRet.Status != "PONG" {
|
||||
t.Error("`ping` failed, result: " + statusRet.Status)
|
||||
}
|
||||
}
|
||||
|
||||
result = client.Send([][]byte{
|
||||
[]byte("SET"),
|
||||
[]byte("a"),
|
||||
[]byte("a"),
|
||||
})
|
||||
if statusRet, ok := result.(*reply.StatusReply); ok {
|
||||
if statusRet.Status != "OK" {
|
||||
t.Error("`set` failed, result: " + statusRet.Status)
|
||||
}
|
||||
}
|
||||
result = client.Send([][]byte{
|
||||
[]byte("SET"),
|
||||
[]byte("a"),
|
||||
[]byte("a"),
|
||||
})
|
||||
if statusRet, ok := result.(*reply.StatusReply); ok {
|
||||
if statusRet.Status != "OK" {
|
||||
t.Error("`set` failed, result: " + statusRet.Status)
|
||||
}
|
||||
}
|
||||
|
||||
result = client.Send([][]byte{
|
||||
[]byte("GET"),
|
||||
[]byte("a"),
|
||||
})
|
||||
if bulkRet, ok := result.(*reply.BulkReply); ok {
|
||||
if string(bulkRet.Arg) != "a" {
|
||||
t.Error("`get` failed, result: " + string(bulkRet.Arg))
|
||||
}
|
||||
}
|
||||
result = client.Send([][]byte{
|
||||
[]byte("GET"),
|
||||
[]byte("a"),
|
||||
})
|
||||
if bulkRet, ok := result.(*reply.BulkReply); ok {
|
||||
if string(bulkRet.Arg) != "a" {
|
||||
t.Error("`get` failed, result: " + string(bulkRet.Arg))
|
||||
}
|
||||
}
|
||||
|
||||
result = client.Send([][]byte{
|
||||
[]byte("DEL"),
|
||||
[]byte("a"),
|
||||
})
|
||||
if intRet, ok := result.(*reply.IntReply); ok {
|
||||
if intRet.Code != 1 {
|
||||
t.Error("`del` failed, result: " + string(intRet.Code))
|
||||
}
|
||||
}
|
||||
result = client.Send([][]byte{
|
||||
[]byte("DEL"),
|
||||
[]byte("a"),
|
||||
})
|
||||
if intRet, ok := result.(*reply.IntReply); ok {
|
||||
if intRet.Code != 1 {
|
||||
t.Error("`del` failed, result: " + string(intRet.Code))
|
||||
}
|
||||
}
|
||||
|
||||
result = client.Send([][]byte{
|
||||
[]byte("GET"),
|
||||
[]byte("a"),
|
||||
})
|
||||
if _, ok := result.(*reply.NullBulkReply); !ok {
|
||||
t.Error("`get` failed, result: " + string(result.ToBytes()))
|
||||
}
|
||||
result = client.Send([][]byte{
|
||||
[]byte("GET"),
|
||||
[]byte("a"),
|
||||
})
|
||||
if _, ok := result.(*reply.NullBulkReply); !ok {
|
||||
t.Error("`get` failed, result: " + string(result.ToBytes()))
|
||||
}
|
||||
|
||||
result = client.Send([][]byte{
|
||||
[]byte("DEL"),
|
||||
[]byte("arr"),
|
||||
})
|
||||
result = client.Send([][]byte{
|
||||
[]byte("DEL"),
|
||||
[]byte("arr"),
|
||||
})
|
||||
|
||||
result = client.Send([][]byte{
|
||||
[]byte("RPUSH"),
|
||||
[]byte("arr"),
|
||||
[]byte("1"),
|
||||
[]byte("2"),
|
||||
[]byte("c"),
|
||||
})
|
||||
if intRet, ok := result.(*reply.IntReply); ok {
|
||||
if intRet.Code != 3 {
|
||||
t.Error("`rpush` failed, result: " + string(intRet.Code))
|
||||
}
|
||||
}
|
||||
result = client.Send([][]byte{
|
||||
[]byte("RPUSH"),
|
||||
[]byte("arr"),
|
||||
[]byte("1"),
|
||||
[]byte("2"),
|
||||
[]byte("c"),
|
||||
})
|
||||
if intRet, ok := result.(*reply.IntReply); ok {
|
||||
if intRet.Code != 3 {
|
||||
t.Error("`rpush` failed, result: " + string(intRet.Code))
|
||||
}
|
||||
}
|
||||
|
||||
result = client.Send([][]byte{
|
||||
[]byte("LRANGE"),
|
||||
[]byte("arr"),
|
||||
[]byte("0"),
|
||||
[]byte("-1"),
|
||||
})
|
||||
if multiBulkRet, ok := result.(*reply.MultiBulkReply); ok {
|
||||
if len(multiBulkRet.Args) != 3 ||
|
||||
string(multiBulkRet.Args[0]) != "1" ||
|
||||
string(multiBulkRet.Args[1]) != "2" ||
|
||||
string(multiBulkRet.Args[2]) != "c" {
|
||||
t.Error("`lrange` failed, result: " + string(multiBulkRet.ToBytes()))
|
||||
}
|
||||
}
|
||||
result = client.Send([][]byte{
|
||||
[]byte("LRANGE"),
|
||||
[]byte("arr"),
|
||||
[]byte("0"),
|
||||
[]byte("-1"),
|
||||
})
|
||||
if multiBulkRet, ok := result.(*reply.MultiBulkReply); ok {
|
||||
if len(multiBulkRet.Args) != 3 ||
|
||||
string(multiBulkRet.Args[0]) != "1" ||
|
||||
string(multiBulkRet.Args[1]) != "2" ||
|
||||
string(multiBulkRet.Args[2]) != "c" {
|
||||
t.Error("`lrange` failed, result: " + string(multiBulkRet.ToBytes()))
|
||||
}
|
||||
}
|
||||
|
||||
client.Close()
|
||||
client.Close()
|
||||
}
|
||||
|
@@ -1,42 +1,42 @@
|
||||
package reply
|
||||
|
||||
type PongReply struct {}
|
||||
type PongReply struct{}
|
||||
|
||||
var PongBytes = []byte("+PONG\r\n")
|
||||
|
||||
func (r *PongReply)ToBytes()[]byte {
|
||||
return PongBytes
|
||||
func (r *PongReply) ToBytes() []byte {
|
||||
return PongBytes
|
||||
}
|
||||
|
||||
type OkReply struct {}
|
||||
type OkReply struct{}
|
||||
|
||||
var okBytes = []byte("+OK\r\n")
|
||||
|
||||
func (r *OkReply)ToBytes()[]byte {
|
||||
return okBytes
|
||||
func (r *OkReply) ToBytes() []byte {
|
||||
return okBytes
|
||||
}
|
||||
|
||||
var nullBulkBytes = []byte("$-1\r\n")
|
||||
|
||||
type NullBulkReply struct {}
|
||||
type NullBulkReply struct{}
|
||||
|
||||
func (r *NullBulkReply)ToBytes()[]byte {
|
||||
return nullBulkBytes
|
||||
func (r *NullBulkReply) ToBytes() []byte {
|
||||
return nullBulkBytes
|
||||
}
|
||||
|
||||
var emptyMultiBulkBytes = []byte("*0\r\n")
|
||||
|
||||
type EmptyMultiBulkReply struct {}
|
||||
type EmptyMultiBulkReply struct{}
|
||||
|
||||
func (r *EmptyMultiBulkReply)ToBytes()[]byte {
|
||||
return emptyMultiBulkBytes
|
||||
func (r *EmptyMultiBulkReply) ToBytes() []byte {
|
||||
return emptyMultiBulkBytes
|
||||
}
|
||||
|
||||
// reply nothing, for commands like subscribe
|
||||
type NoReply struct {}
|
||||
type NoReply struct{}
|
||||
|
||||
var NoBytes = []byte("")
|
||||
|
||||
func (r *NoReply)ToBytes()[]byte {
|
||||
return NoBytes
|
||||
}
|
||||
func (r *NoReply) ToBytes() []byte {
|
||||
return NoBytes
|
||||
}
|
||||
|
@@ -1,67 +1,67 @@
|
||||
package reply
|
||||
|
||||
// UnknownErr
|
||||
type UnknownErrReply struct {}
|
||||
type UnknownErrReply struct{}
|
||||
|
||||
var unknownErrBytes = []byte("-Err unknown\r\n")
|
||||
|
||||
func (r *UnknownErrReply)ToBytes()[]byte {
|
||||
return unknownErrBytes
|
||||
func (r *UnknownErrReply) ToBytes() []byte {
|
||||
return unknownErrBytes
|
||||
}
|
||||
|
||||
func (r *UnknownErrReply) Error()string {
|
||||
return "Err unknown"
|
||||
func (r *UnknownErrReply) Error() string {
|
||||
return "Err unknown"
|
||||
}
|
||||
|
||||
// ArgNumErr
|
||||
type ArgNumErrReply struct {
|
||||
Cmd string
|
||||
Cmd string
|
||||
}
|
||||
|
||||
func (r *ArgNumErrReply)ToBytes()[]byte {
|
||||
return []byte("-ERR wrong number of arguments for '" + r.Cmd + "' command\r\n")
|
||||
func (r *ArgNumErrReply) ToBytes() []byte {
|
||||
return []byte("-ERR wrong number of arguments for '" + r.Cmd + "' command\r\n")
|
||||
}
|
||||
|
||||
func (r *ArgNumErrReply) Error()string {
|
||||
return "ERR wrong number of arguments for '" + r.Cmd + "' command"
|
||||
func (r *ArgNumErrReply) Error() string {
|
||||
return "ERR wrong number of arguments for '" + r.Cmd + "' command"
|
||||
}
|
||||
|
||||
// SyntaxErr
|
||||
type SyntaxErrReply struct {}
|
||||
type SyntaxErrReply struct{}
|
||||
|
||||
var syntaxErrBytes = []byte("-Err syntax error\r\n")
|
||||
|
||||
func (r *SyntaxErrReply)ToBytes()[]byte {
|
||||
return syntaxErrBytes
|
||||
func (r *SyntaxErrReply) ToBytes() []byte {
|
||||
return syntaxErrBytes
|
||||
}
|
||||
|
||||
func (r *SyntaxErrReply)Error()string {
|
||||
return "Err syntax error"
|
||||
func (r *SyntaxErrReply) Error() string {
|
||||
return "Err syntax error"
|
||||
}
|
||||
|
||||
// WrongTypeErr
|
||||
type WrongTypeErrReply struct {}
|
||||
type WrongTypeErrReply struct{}
|
||||
|
||||
var wrongTypeErrBytes = []byte("-WRONGTYPE Operation against a key holding the wrong kind of value\r\n")
|
||||
|
||||
func (r *WrongTypeErrReply)ToBytes()[]byte {
|
||||
return wrongTypeErrBytes
|
||||
func (r *WrongTypeErrReply) ToBytes() []byte {
|
||||
return wrongTypeErrBytes
|
||||
}
|
||||
|
||||
func (r *WrongTypeErrReply)Error()string {
|
||||
return "WRONGTYPE Operation against a key holding the wrong kind of value"
|
||||
func (r *WrongTypeErrReply) Error() string {
|
||||
return "WRONGTYPE Operation against a key holding the wrong kind of value"
|
||||
}
|
||||
|
||||
// ProtocolErr
|
||||
|
||||
type ProtocolErrReply struct {
|
||||
Msg string
|
||||
Msg string
|
||||
}
|
||||
|
||||
func (r *ProtocolErrReply)ToBytes()[]byte {
|
||||
return []byte("-ERR Protocol error: '" + r.Msg + "'\r\n")
|
||||
func (r *ProtocolErrReply) ToBytes() []byte {
|
||||
return []byte("-ERR Protocol error: '" + r.Msg + "'\r\n")
|
||||
}
|
||||
|
||||
func (r *ProtocolErrReply) Error()string {
|
||||
return "ERR Protocol error: '" + r.Msg
|
||||
}
|
||||
func (r *ProtocolErrReply) Error() string {
|
||||
return "ERR Protocol error: '" + r.Msg
|
||||
}
|
||||
|
@@ -1,141 +1,140 @@
|
||||
package reply
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"strconv"
|
||||
"bytes"
|
||||
"github.com/HDT3213/godis/src/interface/redis"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var (
|
||||
nullBulkReplyBytes = []byte("$-1")
|
||||
CRLF = "\r\n"
|
||||
nullBulkReplyBytes = []byte("$-1")
|
||||
CRLF = "\r\n"
|
||||
)
|
||||
|
||||
/* ---- Bulk Reply ---- */
|
||||
|
||||
type BulkReply struct {
|
||||
Arg []byte
|
||||
Arg []byte
|
||||
}
|
||||
|
||||
func MakeBulkReply(arg []byte) *BulkReply {
|
||||
return &BulkReply{
|
||||
Arg: arg,
|
||||
}
|
||||
return &BulkReply{
|
||||
Arg: arg,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *BulkReply) ToBytes() []byte {
|
||||
if len(r.Arg) == 0 {
|
||||
return nullBulkReplyBytes
|
||||
}
|
||||
return []byte("$" + strconv.Itoa(len(r.Arg)) + CRLF + string(r.Arg) + CRLF)
|
||||
if len(r.Arg) == 0 {
|
||||
return nullBulkReplyBytes
|
||||
}
|
||||
return []byte("$" + strconv.Itoa(len(r.Arg)) + CRLF + string(r.Arg) + CRLF)
|
||||
}
|
||||
|
||||
/* ---- Multi Bulk Reply ---- */
|
||||
|
||||
type MultiBulkReply struct {
|
||||
Args [][]byte
|
||||
Args [][]byte
|
||||
}
|
||||
|
||||
func MakeMultiBulkReply(args [][]byte) *MultiBulkReply {
|
||||
return &MultiBulkReply{
|
||||
Args: args,
|
||||
}
|
||||
return &MultiBulkReply{
|
||||
Args: args,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MultiBulkReply) ToBytes() []byte {
|
||||
argLen := len(r.Args)
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("*" + strconv.Itoa(argLen) + CRLF)
|
||||
for _, arg := range r.Args {
|
||||
if arg == nil {
|
||||
buf.WriteString("$-1" + CRLF)
|
||||
} else {
|
||||
buf.WriteString("$" + strconv.Itoa(len(arg)) + CRLF + string(arg) + CRLF)
|
||||
}
|
||||
}
|
||||
return buf.Bytes()
|
||||
argLen := len(r.Args)
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("*" + strconv.Itoa(argLen) + CRLF)
|
||||
for _, arg := range r.Args {
|
||||
if arg == nil {
|
||||
buf.WriteString("$-1" + CRLF)
|
||||
} else {
|
||||
buf.WriteString("$" + strconv.Itoa(len(arg)) + CRLF + string(arg) + CRLF)
|
||||
}
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
/* ---- Multi Raw Reply ---- */
|
||||
|
||||
type MultiRawReply struct {
|
||||
Args [][]byte
|
||||
Args [][]byte
|
||||
}
|
||||
|
||||
func MakeMultiRawReply(args [][]byte) *MultiRawReply {
|
||||
return &MultiRawReply{
|
||||
Args: args,
|
||||
}
|
||||
return &MultiRawReply{
|
||||
Args: args,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *MultiRawReply) ToBytes() []byte {
|
||||
argLen := len(r.Args)
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("*" + strconv.Itoa(argLen) + CRLF)
|
||||
for _, arg := range r.Args {
|
||||
buf.Write(arg)
|
||||
}
|
||||
return buf.Bytes()
|
||||
argLen := len(r.Args)
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("*" + strconv.Itoa(argLen) + CRLF)
|
||||
for _, arg := range r.Args {
|
||||
buf.Write(arg)
|
||||
}
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
/* ---- Status Reply ---- */
|
||||
|
||||
type StatusReply struct {
|
||||
Status string
|
||||
Status string
|
||||
}
|
||||
|
||||
func MakeStatusReply(status string) *StatusReply {
|
||||
return &StatusReply{
|
||||
Status: status,
|
||||
}
|
||||
return &StatusReply{
|
||||
Status: status,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *StatusReply) ToBytes() []byte {
|
||||
return []byte("+" + r.Status + "\r\n")
|
||||
return []byte("+" + r.Status + "\r\n")
|
||||
}
|
||||
|
||||
/* ---- Int Reply ---- */
|
||||
|
||||
type IntReply struct {
|
||||
Code int64
|
||||
Code int64
|
||||
}
|
||||
|
||||
func MakeIntReply(code int64) *IntReply {
|
||||
return &IntReply{
|
||||
Code: code,
|
||||
}
|
||||
return &IntReply{
|
||||
Code: code,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *IntReply) ToBytes() []byte {
|
||||
return []byte(":" + strconv.FormatInt(r.Code, 10) + CRLF)
|
||||
return []byte(":" + strconv.FormatInt(r.Code, 10) + CRLF)
|
||||
}
|
||||
|
||||
|
||||
/* ---- Error Reply ---- */
|
||||
|
||||
type ErrorReply interface {
|
||||
Error() string
|
||||
ToBytes() []byte
|
||||
Error() string
|
||||
ToBytes() []byte
|
||||
}
|
||||
|
||||
type StandardErrReply struct {
|
||||
Status string
|
||||
Status string
|
||||
}
|
||||
|
||||
func MakeErrReply(status string) *StandardErrReply {
|
||||
return &StandardErrReply{
|
||||
Status: status,
|
||||
}
|
||||
return &StandardErrReply{
|
||||
Status: status,
|
||||
}
|
||||
}
|
||||
|
||||
func IsErrorReply(reply redis.Reply) bool {
|
||||
return reply.ToBytes()[0] == '-'
|
||||
return reply.ToBytes()[0] == '-'
|
||||
}
|
||||
|
||||
func (r *StandardErrReply) ToBytes() []byte {
|
||||
return []byte("-" + r.Status + "\r\n")
|
||||
return []byte("-" + r.Status + "\r\n")
|
||||
}
|
||||
|
||||
func (r *StandardErrReply) Error() string {
|
||||
return r.Status
|
||||
}
|
||||
return r.Status
|
||||
}
|
||||
|
@@ -1,95 +1,95 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/HDT3213/godis/src/lib/sync/atomic"
|
||||
"github.com/HDT3213/godis/src/lib/sync/wait"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
"github.com/HDT3213/godis/src/lib/sync/atomic"
|
||||
"github.com/HDT3213/godis/src/lib/sync/wait"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// abstract of active client
|
||||
type Client struct {
|
||||
conn net.Conn
|
||||
conn net.Conn
|
||||
|
||||
// waiting util reply finished
|
||||
waitingReply wait.Wait
|
||||
// waiting util reply finished
|
||||
waitingReply wait.Wait
|
||||
|
||||
// is sending request in progress
|
||||
uploading atomic.AtomicBool
|
||||
// multi bulk msg lineCount - 1(first line)
|
||||
expectedArgsCount uint32
|
||||
// sent line count, exclude first line
|
||||
receivedCount uint32
|
||||
// sent lines, exclude first line
|
||||
args [][]byte
|
||||
// is sending request in progress
|
||||
uploading atomic.AtomicBool
|
||||
// multi bulk msg lineCount - 1(first line)
|
||||
expectedArgsCount uint32
|
||||
// sent line count, exclude first line
|
||||
receivedCount uint32
|
||||
// sent lines, exclude first line
|
||||
args [][]byte
|
||||
|
||||
// lock while server sending response
|
||||
mu sync.Mutex
|
||||
// lock while server sending response
|
||||
mu sync.Mutex
|
||||
|
||||
// subscribing channels
|
||||
subs map[string]bool
|
||||
// subscribing channels
|
||||
subs map[string]bool
|
||||
}
|
||||
|
||||
func (c *Client)Close()error {
|
||||
c.waitingReply.WaitWithTimeout(10 * time.Second)
|
||||
_ = c.conn.Close()
|
||||
return nil
|
||||
func (c *Client) Close() error {
|
||||
c.waitingReply.WaitWithTimeout(10 * time.Second)
|
||||
_ = c.conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func MakeClient(conn net.Conn) *Client {
|
||||
return &Client{
|
||||
conn: conn,
|
||||
}
|
||||
return &Client{
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client)Write(b []byte)error {
|
||||
if b == nil || len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
func (c *Client) Write(b []byte) error {
|
||||
if b == nil || len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
_, err := c.conn.Write(b)
|
||||
return err
|
||||
_, err := c.conn.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client)SubsChannel(channel string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
func (c *Client) SubsChannel(channel string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.subs == nil {
|
||||
c.subs = make(map[string]bool)
|
||||
}
|
||||
c.subs[channel] = true
|
||||
if c.subs == nil {
|
||||
c.subs = make(map[string]bool)
|
||||
}
|
||||
c.subs[channel] = true
|
||||
}
|
||||
|
||||
func (c *Client)UnSubsChannel(channel string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
func (c *Client) UnSubsChannel(channel string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.subs == nil {
|
||||
return
|
||||
}
|
||||
delete(c.subs, channel)
|
||||
if c.subs == nil {
|
||||
return
|
||||
}
|
||||
delete(c.subs, channel)
|
||||
}
|
||||
|
||||
func (c *Client)SubsCount()int {
|
||||
if c.subs == nil {
|
||||
return 0
|
||||
}
|
||||
return len(c.subs)
|
||||
func (c *Client) SubsCount() int {
|
||||
if c.subs == nil {
|
||||
return 0
|
||||
}
|
||||
return len(c.subs)
|
||||
}
|
||||
|
||||
func (c *Client)GetChannels()[]string {
|
||||
if c.subs == nil {
|
||||
return make([]string, 0)
|
||||
}
|
||||
channels := make([]string, len(c.subs))
|
||||
i := 0
|
||||
for channel := range c.subs {
|
||||
channels[i] = channel
|
||||
i++
|
||||
}
|
||||
return channels
|
||||
}
|
||||
func (c *Client) GetChannels() []string {
|
||||
if c.subs == nil {
|
||||
return make([]string, 0)
|
||||
}
|
||||
channels := make([]string, len(c.subs))
|
||||
i := 0
|
||||
for channel := range c.subs {
|
||||
channels[i] = channel
|
||||
i++
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
@@ -5,192 +5,192 @@ package server
|
||||
*/
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"github.com/HDT3213/godis/src/cluster"
|
||||
"github.com/HDT3213/godis/src/config"
|
||||
DBImpl "github.com/HDT3213/godis/src/db"
|
||||
"github.com/HDT3213/godis/src/interface/db"
|
||||
"github.com/HDT3213/godis/src/lib/logger"
|
||||
"github.com/HDT3213/godis/src/lib/sync/atomic"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"bufio"
|
||||
"context"
|
||||
"github.com/HDT3213/godis/src/cluster"
|
||||
"github.com/HDT3213/godis/src/config"
|
||||
DBImpl "github.com/HDT3213/godis/src/db"
|
||||
"github.com/HDT3213/godis/src/interface/db"
|
||||
"github.com/HDT3213/godis/src/lib/logger"
|
||||
"github.com/HDT3213/godis/src/lib/sync/atomic"
|
||||
"github.com/HDT3213/godis/src/redis/reply"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
UnknownErrReplyBytes = []byte("-ERR unknown\r\n")
|
||||
UnknownErrReplyBytes = []byte("-ERR unknown\r\n")
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
activeConn sync.Map // *client -> placeholder
|
||||
db db.DB
|
||||
closing atomic.AtomicBool // refusing new client and new request
|
||||
activeConn sync.Map // *client -> placeholder
|
||||
db db.DB
|
||||
closing atomic.AtomicBool // refusing new client and new request
|
||||
}
|
||||
|
||||
func MakeHandler() *Handler {
|
||||
var db db.DB
|
||||
if config.Properties.Peers != nil &&
|
||||
len(config.Properties.Peers) > 0 {
|
||||
db = cluster.MakeCluster()
|
||||
} else {
|
||||
db = DBImpl.MakeDB()
|
||||
}
|
||||
return &Handler{
|
||||
db: db,
|
||||
}
|
||||
var db db.DB
|
||||
if config.Properties.Peers != nil &&
|
||||
len(config.Properties.Peers) > 0 {
|
||||
db = cluster.MakeCluster()
|
||||
} else {
|
||||
db = DBImpl.MakeDB()
|
||||
}
|
||||
return &Handler{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) closeClient(client *Client) {
|
||||
_ = client.Close()
|
||||
h.db.AfterClientClose(client)
|
||||
h.activeConn.Delete(client)
|
||||
_ = client.Close()
|
||||
h.db.AfterClientClose(client)
|
||||
h.activeConn.Delete(client)
|
||||
}
|
||||
|
||||
func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
if h.closing.Get() {
|
||||
// closing handler refuse new connection
|
||||
_ = conn.Close()
|
||||
}
|
||||
if h.closing.Get() {
|
||||
// closing handler refuse new connection
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
client := MakeClient(conn)
|
||||
h.activeConn.Store(client, 1)
|
||||
client := MakeClient(conn)
|
||||
h.activeConn.Store(client, 1)
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
var fixedLen int64 = 0
|
||||
var err error
|
||||
var msg []byte
|
||||
for {
|
||||
if fixedLen == 0 {
|
||||
msg, err = reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF ||
|
||||
err == io.ErrUnexpectedEOF ||
|
||||
strings.Contains(err.Error(), "use of closed network connection") {
|
||||
logger.Info("connection close")
|
||||
} else {
|
||||
logger.Warn(err)
|
||||
}
|
||||
reader := bufio.NewReader(conn)
|
||||
var fixedLen int64 = 0
|
||||
var err error
|
||||
var msg []byte
|
||||
for {
|
||||
if fixedLen == 0 {
|
||||
msg, err = reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF ||
|
||||
err == io.ErrUnexpectedEOF ||
|
||||
strings.Contains(err.Error(), "use of closed network connection") {
|
||||
logger.Info("connection close")
|
||||
} else {
|
||||
logger.Warn(err)
|
||||
}
|
||||
|
||||
// after client close
|
||||
h.closeClient(client)
|
||||
return // io error, disconnect with client
|
||||
}
|
||||
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
|
||||
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
|
||||
_, _ = client.conn.Write(errReply.ToBytes())
|
||||
}
|
||||
} else {
|
||||
msg = make([]byte, fixedLen+2)
|
||||
_, err = io.ReadFull(reader, msg)
|
||||
if err != nil {
|
||||
if err == io.EOF ||
|
||||
err == io.ErrUnexpectedEOF ||
|
||||
strings.Contains(err.Error(), "use of closed network connection") {
|
||||
logger.Info("connection close")
|
||||
} else {
|
||||
logger.Warn(err)
|
||||
}
|
||||
// after client close
|
||||
h.closeClient(client)
|
||||
return // io error, disconnect with client
|
||||
}
|
||||
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
|
||||
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
|
||||
_, _ = client.conn.Write(errReply.ToBytes())
|
||||
}
|
||||
} else {
|
||||
msg = make([]byte, fixedLen+2)
|
||||
_, err = io.ReadFull(reader, msg)
|
||||
if err != nil {
|
||||
if err == io.EOF ||
|
||||
err == io.ErrUnexpectedEOF ||
|
||||
strings.Contains(err.Error(), "use of closed network connection") {
|
||||
logger.Info("connection close")
|
||||
} else {
|
||||
logger.Warn(err)
|
||||
}
|
||||
|
||||
// after client close
|
||||
h.closeClient(client)
|
||||
return // io error, disconnect with client
|
||||
}
|
||||
if len(msg) == 0 ||
|
||||
msg[len(msg)-2] != '\r' ||
|
||||
msg[len(msg)-1] != '\n' {
|
||||
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
|
||||
_, _ = client.conn.Write(errReply.ToBytes())
|
||||
}
|
||||
fixedLen = 0
|
||||
}
|
||||
// after client close
|
||||
h.closeClient(client)
|
||||
return // io error, disconnect with client
|
||||
}
|
||||
if len(msg) == 0 ||
|
||||
msg[len(msg)-2] != '\r' ||
|
||||
msg[len(msg)-1] != '\n' {
|
||||
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
|
||||
_, _ = client.conn.Write(errReply.ToBytes())
|
||||
}
|
||||
fixedLen = 0
|
||||
}
|
||||
|
||||
if !client.uploading.Get() {
|
||||
// new request
|
||||
if msg[0] == '*' {
|
||||
// bulk multi msg
|
||||
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
|
||||
if err != nil {
|
||||
_, _ = client.conn.Write(UnknownErrReplyBytes)
|
||||
continue
|
||||
}
|
||||
client.waitingReply.Add(1)
|
||||
client.uploading.Set(true)
|
||||
client.expectedArgsCount = uint32(expectedLine)
|
||||
client.receivedCount = 0
|
||||
client.args = make([][]byte, expectedLine)
|
||||
} else {
|
||||
// text protocol
|
||||
// remove \r or \n or \r\n in the end of line
|
||||
str := strings.TrimSuffix(string(msg), "\n")
|
||||
str = strings.TrimSuffix(str, "\r")
|
||||
strs := strings.Split(str, " ")
|
||||
args := make([][]byte, len(strs))
|
||||
for i, s := range strs {
|
||||
args[i] = []byte(s)
|
||||
}
|
||||
if !client.uploading.Get() {
|
||||
// new request
|
||||
if msg[0] == '*' {
|
||||
// bulk multi msg
|
||||
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
|
||||
if err != nil {
|
||||
_, _ = client.conn.Write(UnknownErrReplyBytes)
|
||||
continue
|
||||
}
|
||||
client.waitingReply.Add(1)
|
||||
client.uploading.Set(true)
|
||||
client.expectedArgsCount = uint32(expectedLine)
|
||||
client.receivedCount = 0
|
||||
client.args = make([][]byte, expectedLine)
|
||||
} else {
|
||||
// text protocol
|
||||
// remove \r or \n or \r\n in the end of line
|
||||
str := strings.TrimSuffix(string(msg), "\n")
|
||||
str = strings.TrimSuffix(str, "\r")
|
||||
strs := strings.Split(str, " ")
|
||||
args := make([][]byte, len(strs))
|
||||
for i, s := range strs {
|
||||
args[i] = []byte(s)
|
||||
}
|
||||
|
||||
// send reply
|
||||
result := h.db.Exec(client, args)
|
||||
if result != nil {
|
||||
_ = client.Write(result.ToBytes())
|
||||
} else {
|
||||
_ = client.Write(UnknownErrReplyBytes)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// receive following part of a request
|
||||
line := msg[0 : len(msg)-2]
|
||||
if line[0] == '$' {
|
||||
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
|
||||
if err != nil {
|
||||
errReply := &reply.ProtocolErrReply{Msg: err.Error()}
|
||||
_, _ = client.conn.Write(errReply.ToBytes())
|
||||
}
|
||||
if fixedLen <= 0 {
|
||||
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
|
||||
_, _ = client.conn.Write(errReply.ToBytes())
|
||||
}
|
||||
} else {
|
||||
client.args[client.receivedCount] = line
|
||||
client.receivedCount++
|
||||
}
|
||||
// send reply
|
||||
result := h.db.Exec(client, args)
|
||||
if result != nil {
|
||||
_ = client.Write(result.ToBytes())
|
||||
} else {
|
||||
_ = client.Write(UnknownErrReplyBytes)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// receive following part of a request
|
||||
line := msg[0 : len(msg)-2]
|
||||
if line[0] == '$' {
|
||||
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
|
||||
if err != nil {
|
||||
errReply := &reply.ProtocolErrReply{Msg: err.Error()}
|
||||
_, _ = client.conn.Write(errReply.ToBytes())
|
||||
}
|
||||
if fixedLen <= 0 {
|
||||
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
|
||||
_, _ = client.conn.Write(errReply.ToBytes())
|
||||
}
|
||||
} else {
|
||||
client.args[client.receivedCount] = line
|
||||
client.receivedCount++
|
||||
}
|
||||
|
||||
// if sending finished
|
||||
if client.receivedCount == client.expectedArgsCount {
|
||||
client.uploading.Set(false) // finish sending progress
|
||||
// if sending finished
|
||||
if client.receivedCount == client.expectedArgsCount {
|
||||
client.uploading.Set(false) // finish sending progress
|
||||
|
||||
// send reply
|
||||
result := h.db.Exec(client, client.args)
|
||||
if result != nil {
|
||||
_ = client.Write(result.ToBytes())
|
||||
} else {
|
||||
_ = client.Write(UnknownErrReplyBytes)
|
||||
}
|
||||
// send reply
|
||||
result := h.db.Exec(client, client.args)
|
||||
if result != nil {
|
||||
_ = client.Write(result.ToBytes())
|
||||
} else {
|
||||
_ = client.Write(UnknownErrReplyBytes)
|
||||
}
|
||||
|
||||
// finish reply
|
||||
client.expectedArgsCount = 0
|
||||
client.receivedCount = 0
|
||||
client.args = nil
|
||||
client.waitingReply.Done()
|
||||
}
|
||||
}
|
||||
// finish reply
|
||||
client.expectedArgsCount = 0
|
||||
client.receivedCount = 0
|
||||
client.args = nil
|
||||
client.waitingReply.Done()
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) Close() error {
|
||||
logger.Info("handler shuting down...")
|
||||
h.closing.Set(true)
|
||||
// TODO: concurrent wait
|
||||
h.activeConn.Range(func(key interface{}, val interface{}) bool {
|
||||
client := key.(*Client)
|
||||
_ = client.Close()
|
||||
return true
|
||||
})
|
||||
h.db.Close()
|
||||
return nil
|
||||
logger.Info("handler shuting down...")
|
||||
h.closing.Set(true)
|
||||
// TODO: concurrent wait
|
||||
h.activeConn.Range(func(key interface{}, val interface{}) bool {
|
||||
client := key.(*Client)
|
||||
_ = client.Close()
|
||||
return true
|
||||
})
|
||||
h.db.Close()
|
||||
return nil
|
||||
}
|
||||
|
119
src/tcp/echo.go
119
src/tcp/echo.go
@@ -5,79 +5,78 @@ package tcp
|
||||
*/
|
||||
|
||||
import (
|
||||
"net"
|
||||
"context"
|
||||
"bufio"
|
||||
"github.com/HDT3213/godis/src/lib/logger"
|
||||
"sync"
|
||||
"io"
|
||||
"github.com/HDT3213/godis/src/lib/sync/atomic"
|
||||
"time"
|
||||
"github.com/HDT3213/godis/src/lib/sync/wait"
|
||||
"bufio"
|
||||
"context"
|
||||
"github.com/HDT3213/godis/src/lib/logger"
|
||||
"github.com/HDT3213/godis/src/lib/sync/atomic"
|
||||
"github.com/HDT3213/godis/src/lib/sync/wait"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type EchoHandler struct {
|
||||
activeConn sync.Map
|
||||
closing atomic.AtomicBool
|
||||
activeConn sync.Map
|
||||
closing atomic.AtomicBool
|
||||
}
|
||||
|
||||
func MakeEchoHandler()(*EchoHandler) {
|
||||
return &EchoHandler{
|
||||
}
|
||||
func MakeEchoHandler() *EchoHandler {
|
||||
return &EchoHandler{}
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
Conn net.Conn
|
||||
Waiting wait.Wait
|
||||
Conn net.Conn
|
||||
Waiting wait.Wait
|
||||
}
|
||||
|
||||
func (c *Client)Close()error {
|
||||
c.Waiting.WaitWithTimeout(10 * time.Second)
|
||||
c.Conn.Close()
|
||||
return nil
|
||||
func (c *Client) Close() error {
|
||||
c.Waiting.WaitWithTimeout(10 * time.Second)
|
||||
c.Conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *EchoHandler)Handle(ctx context.Context, conn net.Conn) {
|
||||
if h.closing.Get() {
|
||||
// closing handler refuse new connection
|
||||
conn.Close()
|
||||
}
|
||||
func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||
if h.closing.Get() {
|
||||
// closing handler refuse new connection
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
client := &Client {
|
||||
Conn: conn,
|
||||
}
|
||||
h.activeConn.Store(client, 1)
|
||||
client := &Client{
|
||||
Conn: conn,
|
||||
}
|
||||
h.activeConn.Store(client, 1)
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
// may occurs: client EOF, client timeout, server early close
|
||||
msg, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
logger.Info("connection close")
|
||||
h.activeConn.Delete(client)
|
||||
} else {
|
||||
logger.Warn(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
client.Waiting.Add(1)
|
||||
//logger.Info("sleeping")
|
||||
//time.Sleep(10 * time.Second)
|
||||
b := []byte(msg)
|
||||
conn.Write(b)
|
||||
client.Waiting.Done()
|
||||
}
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
// may occurs: client EOF, client timeout, server early close
|
||||
msg, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
logger.Info("connection close")
|
||||
h.activeConn.Delete(client)
|
||||
} else {
|
||||
logger.Warn(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
client.Waiting.Add(1)
|
||||
//logger.Info("sleeping")
|
||||
//time.Sleep(10 * time.Second)
|
||||
b := []byte(msg)
|
||||
conn.Write(b)
|
||||
client.Waiting.Done()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *EchoHandler)Close()error {
|
||||
logger.Info("handler shuting down...")
|
||||
h.closing.Set(true)
|
||||
// TODO: concurrent wait
|
||||
h.activeConn.Range(func(key interface{}, val interface{})bool {
|
||||
client := key.(*Client)
|
||||
client.Close()
|
||||
return true
|
||||
})
|
||||
return nil
|
||||
}
|
||||
func (h *EchoHandler) Close() error {
|
||||
logger.Info("handler shuting down...")
|
||||
h.closing.Set(true)
|
||||
// TODO: concurrent wait
|
||||
h.activeConn.Range(func(key interface{}, val interface{}) bool {
|
||||
client := key.(*Client)
|
||||
client.Close()
|
||||
return true
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
@@ -5,75 +5,74 @@ package tcp
|
||||
*/
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/HDT3213/godis/src/interface/tcp"
|
||||
"github.com/HDT3213/godis/src/lib/logger"
|
||||
"github.com/HDT3213/godis/src/lib/sync/atomic"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/HDT3213/godis/src/interface/tcp"
|
||||
"github.com/HDT3213/godis/src/lib/logger"
|
||||
"github.com/HDT3213/godis/src/lib/sync/atomic"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Address string `yaml:"address"`
|
||||
MaxConnect uint32 `yaml:"max-connect"`
|
||||
Timeout time.Duration `yaml:"timeout"`
|
||||
Address string `yaml:"address"`
|
||||
MaxConnect uint32 `yaml:"max-connect"`
|
||||
Timeout time.Duration `yaml:"timeout"`
|
||||
}
|
||||
|
||||
func ListenAndServe(cfg *Config, handler tcp.Handler) {
|
||||
listener, err := net.Listen("tcp", cfg.Address)
|
||||
if err != nil {
|
||||
logger.Fatal(fmt.Sprintf("listen err: %v", err))
|
||||
}
|
||||
listener, err := net.Listen("tcp", cfg.Address)
|
||||
if err != nil {
|
||||
logger.Fatal(fmt.Sprintf("listen err: %v", err))
|
||||
}
|
||||
|
||||
// listen signal
|
||||
var closing atomic.AtomicBool
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT)
|
||||
go func() {
|
||||
sig := <-sigCh
|
||||
switch sig {
|
||||
case syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT:
|
||||
logger.Info("shuting down...")
|
||||
closing.Set(true)
|
||||
_ = listener.Close() // listener.Accept() will return err immediately
|
||||
_ = handler.Close() // close connections
|
||||
}
|
||||
}()
|
||||
// listen signal
|
||||
var closing atomic.AtomicBool
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT)
|
||||
go func() {
|
||||
sig := <-sigCh
|
||||
switch sig {
|
||||
case syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT:
|
||||
logger.Info("shuting down...")
|
||||
closing.Set(true)
|
||||
_ = listener.Close() // listener.Accept() will return err immediately
|
||||
_ = handler.Close() // close connections
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
// listen port
|
||||
logger.Info(fmt.Sprintf("bind: %s, start listening...", cfg.Address))
|
||||
defer func() {
|
||||
// close during unexpected error
|
||||
_ = listener.Close()
|
||||
_ = handler.Close()
|
||||
}()
|
||||
ctx, _ := context.WithCancel(context.Background())
|
||||
var waitDone sync.WaitGroup
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
if closing.Get() {
|
||||
logger.Info("waiting disconnect...")
|
||||
waitDone.Wait()
|
||||
return // handler will be closed by defer
|
||||
}
|
||||
logger.Error(fmt.Sprintf("accept err: %v", err))
|
||||
continue
|
||||
}
|
||||
// handle
|
||||
logger.Info("accept link")
|
||||
waitDone.Add(1)
|
||||
go func() {
|
||||
defer func() {
|
||||
waitDone.Done()
|
||||
}()
|
||||
handler.Handle(ctx, conn)
|
||||
}()
|
||||
}
|
||||
// listen port
|
||||
logger.Info(fmt.Sprintf("bind: %s, start listening...", cfg.Address))
|
||||
defer func() {
|
||||
// close during unexpected error
|
||||
_ = listener.Close()
|
||||
_ = handler.Close()
|
||||
}()
|
||||
ctx, _ := context.WithCancel(context.Background())
|
||||
var waitDone sync.WaitGroup
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
if closing.Get() {
|
||||
logger.Info("waiting disconnect...")
|
||||
waitDone.Wait()
|
||||
return // handler will be closed by defer
|
||||
}
|
||||
logger.Error(fmt.Sprintf("accept err: %v", err))
|
||||
continue
|
||||
}
|
||||
// handle
|
||||
logger.Info("accept link")
|
||||
waitDone.Add(1)
|
||||
go func() {
|
||||
defer func() {
|
||||
waitDone.Done()
|
||||
}()
|
||||
handler.Handle(ctx, conn)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user