mirror of
https://github.com/HDT3213/godis.git
synced 2025-10-05 08:46:56 +08:00
reformat code
This commit is contained in:
17
README.md
17
README.md
@@ -2,11 +2,13 @@
|
|||||||
|
|
||||||
[中文版](https://github.com/HDT3213/godis/blob/master/README_CN.md)
|
[中文版](https://github.com/HDT3213/godis/blob/master/README_CN.md)
|
||||||
|
|
||||||
`Godis` is a simple implementation of Redis Server, which intents to provide an example of writing a high concurrent middleware using golang.
|
`Godis` is a simple implementation of Redis Server, which intents to provide an example of writing a high concurrent
|
||||||
|
middleware using golang.
|
||||||
|
|
||||||
Please be advised, NEVER think about using this in production environment.
|
Please be advised, NEVER think about using this in production environment.
|
||||||
|
|
||||||
This repository implemented most features of redis, including 5 data structures, ttl, publish/subscribe, AOF persistence and server side cluster mode.
|
This repository implemented most features of redis, including 5 data structures, ttl, publish/subscribe, AOF persistence
|
||||||
|
and server side cluster mode.
|
||||||
|
|
||||||
If you could read Chinese, you can find more details in [My Blog](https://www.cnblogs.com/Finley/category/1598973.html).
|
If you could read Chinese, you can find more details in [My Blog](https://www.cnblogs.com/Finley/category/1598973.html).
|
||||||
|
|
||||||
@@ -22,7 +24,7 @@ You could use redis-cli or other redis client to connect godis server, which lis
|
|||||||
|
|
||||||
The program will try to read config file path from environment variable `CONFIG`.
|
The program will try to read config file path from environment variable `CONFIG`.
|
||||||
|
|
||||||
If environment variable is not set, then the program try to read `redis.conf` in the working directory.
|
If environment variable is not set, then the program try to read `redis.conf` in the working directory.
|
||||||
|
|
||||||
If there is no such file, then the program will run with default config.
|
If there is no such file, then the program will run with default config.
|
||||||
|
|
||||||
@@ -35,8 +37,7 @@ peers localhost:7379,localhost:7389 // other node in cluster
|
|||||||
self localhost:6399 // self address
|
self localhost:6399 // self address
|
||||||
```
|
```
|
||||||
|
|
||||||
We provide node1.conf and node2.conf for demonstration.
|
We provide node1.conf and node2.conf for demonstration. use following command line to start a two-node-cluster:
|
||||||
use following command line to start a two-node-cluster:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
CONFIG=node1.conf ./godis-darwin &
|
CONFIG=node1.conf ./godis-darwin &
|
||||||
@@ -151,7 +152,7 @@ Supported Commands:
|
|||||||
If you want to read my code in this repository, here is a simple guidance.
|
If you want to read my code in this repository, here is a simple guidance.
|
||||||
|
|
||||||
- cmd: only the entry point
|
- cmd: only the entry point
|
||||||
- config: config parser
|
- config: config parser
|
||||||
- interface: some interface definitions
|
- interface: some interface definitions
|
||||||
- lib: some utils, such as logger, sync utils and wildcard
|
- lib: some utils, such as logger, sync utils and wildcard
|
||||||
|
|
||||||
@@ -167,7 +168,7 @@ I suggest focusing on the following directories:
|
|||||||
- sortedset: a sorted set implements based on skiplist
|
- sortedset: a sorted set implements based on skiplist
|
||||||
- db: the implements of the redis db
|
- db: the implements of the redis db
|
||||||
- db.go: the basement of database
|
- db.go: the basement of database
|
||||||
- router.go: it find handler for commands
|
- router.go: it find handler for commands
|
||||||
- keys.go: handlers for keys commands
|
- keys.go: handlers for keys commands
|
||||||
- string.go: handlers for string commands
|
- string.go: handlers for string commands
|
||||||
- list.go: handlers for list commands
|
- list.go: handlers for list commands
|
||||||
@@ -176,7 +177,7 @@ I suggest focusing on the following directories:
|
|||||||
- sortedset.go: handlers for sorted set commands
|
- sortedset.go: handlers for sorted set commands
|
||||||
- pubsub.go: implements of publish / subscribe
|
- pubsub.go: implements of publish / subscribe
|
||||||
- aof.go: implements of AOF persistence and rewrite
|
- aof.go: implements of AOF persistence and rewrite
|
||||||
|
|
||||||
# License
|
# License
|
||||||
|
|
||||||
This project is licensed under the [GPL license](https://github.com/HDT3213/godis/blob/master/LICENSE).
|
This project is licensed under the [GPL license](https://github.com/HDT3213/godis/blob/master/LICENSE).
|
34
README_CN.md
34
README_CN.md
@@ -2,8 +2,8 @@ Godis 是一个用 Go 语言实现的 Redis 服务器。本项目旨在为尝试
|
|||||||
|
|
||||||
**请注意:不要在生产环境使用使用此项目**
|
**请注意:不要在生产环境使用使用此项目**
|
||||||
|
|
||||||
Godis 实现了 Redis 的大多数功能,包括5种数据结构、TTL、发布订阅以及 AOF 持久化。可以在[我的博客](https://www.cnblogs.com/Finley/category/1598973.html)了解更多关于 Godis 的信息。
|
Godis 实现了 Redis 的大多数功能,包括5种数据结构、TTL、发布订阅以及 AOF 持久化。可以在[我的博客](https://www.cnblogs.com/Finley/category/1598973.html)了解更多关于
|
||||||
|
Godis 的信息。
|
||||||
|
|
||||||
# 运行 Godis
|
# 运行 Godis
|
||||||
|
|
||||||
@@ -149,19 +149,19 @@ redis-cli -p 6399
|
|||||||
- tcp: tcp 服务器实现
|
- tcp: tcp 服务器实现
|
||||||
- redis: redis 协议解析器
|
- redis: redis 协议解析器
|
||||||
- datastruct: redis 的各类数据结构实现
|
- datastruct: redis 的各类数据结构实现
|
||||||
- dict: hash 表
|
- dict: hash 表
|
||||||
- list: 链表
|
- list: 链表
|
||||||
- lock: 用于锁定 key 的锁组件
|
- lock: 用于锁定 key 的锁组件
|
||||||
- set: 基于hash表的集合
|
- set: 基于hash表的集合
|
||||||
- sortedset: 基于跳表实现的有序集合
|
- sortedset: 基于跳表实现的有序集合
|
||||||
- db: redis 存储引擎实现
|
- db: redis 存储引擎实现
|
||||||
- db.go: 引擎的基础功能
|
- db.go: 引擎的基础功能
|
||||||
- router.go: 将命令路由给响应的处理函数
|
- router.go: 将命令路由给响应的处理函数
|
||||||
- keys.go: del、ttl、expire 等通用命令实现
|
- keys.go: del、ttl、expire 等通用命令实现
|
||||||
- string.go: get、set 等字符串命令实现
|
- string.go: get、set 等字符串命令实现
|
||||||
- list.go: lpush、lindex 等列表命令实现
|
- list.go: lpush、lindex 等列表命令实现
|
||||||
- hash.go: hget、hset 等哈希表命令实现
|
- hash.go: hget、hset 等哈希表命令实现
|
||||||
- set.go: sadd 等集合命令实现
|
- set.go: sadd 等集合命令实现
|
||||||
- sortedset.go: zadd 等有序集合命令实现
|
- sortedset.go: zadd 等有序集合命令实现
|
||||||
- pubsub.go: 发布订阅命令实现
|
- pubsub.go: 发布订阅命令实现
|
||||||
- aof.go: aof持久化实现
|
- aof.go: aof持久化实现
|
@@ -1,45 +1,45 @@
|
|||||||
package cluster
|
package cluster
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/HDT3213/godis/src/redis/client"
|
"github.com/HDT3213/godis/src/redis/client"
|
||||||
"github.com/jolestar/go-commons-pool/v2"
|
"github.com/jolestar/go-commons-pool/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConnectionFactory struct {
|
type ConnectionFactory struct {
|
||||||
Peer string
|
Peer string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *ConnectionFactory) MakeObject(ctx context.Context) (*pool.PooledObject, error) {
|
func (f *ConnectionFactory) MakeObject(ctx context.Context) (*pool.PooledObject, error) {
|
||||||
c, err := client.MakeClient(f.Peer)
|
c, err := client.MakeClient(f.Peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c.Start()
|
c.Start()
|
||||||
return pool.NewPooledObject(c), nil
|
return pool.NewPooledObject(c), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *ConnectionFactory) DestroyObject(ctx context.Context, object *pool.PooledObject) error {
|
func (f *ConnectionFactory) DestroyObject(ctx context.Context, object *pool.PooledObject) error {
|
||||||
c, ok := object.Object.(*client.Client)
|
c, ok := object.Object.(*client.Client)
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("type mismatch")
|
return errors.New("type mismatch")
|
||||||
}
|
}
|
||||||
c.Close()
|
c.Close()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *ConnectionFactory) ValidateObject(ctx context.Context, object *pool.PooledObject) bool {
|
func (f *ConnectionFactory) ValidateObject(ctx context.Context, object *pool.PooledObject) bool {
|
||||||
// do validate
|
// do validate
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *ConnectionFactory) ActivateObject(ctx context.Context, object *pool.PooledObject) error {
|
func (f *ConnectionFactory) ActivateObject(ctx context.Context, object *pool.PooledObject) error {
|
||||||
// do activate
|
// do activate
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *ConnectionFactory) PassivateObject(ctx context.Context, object *pool.PooledObject) error {
|
func (f *ConnectionFactory) PassivateObject(ctx context.Context, object *pool.PooledObject) error {
|
||||||
// do passivate
|
// do passivate
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -1,99 +1,99 @@
|
|||||||
package cluster
|
package cluster
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/HDT3213/godis/src/interface/redis"
|
"github.com/HDT3213/godis/src/interface/redis"
|
||||||
"github.com/HDT3213/godis/src/redis/reply"
|
"github.com/HDT3213/godis/src/redis/reply"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||||
if len(args) < 2 {
|
if len(args) < 2 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'del' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'del' command")
|
||||||
}
|
}
|
||||||
keys := make([]string, len(args)-1)
|
keys := make([]string, len(args)-1)
|
||||||
for i := 1; i < len(args); i++ {
|
for i := 1; i < len(args); i++ {
|
||||||
keys[i-1] = string(args[i])
|
keys[i-1] = string(args[i])
|
||||||
}
|
}
|
||||||
groupMap := cluster.groupBy(keys)
|
groupMap := cluster.groupBy(keys)
|
||||||
if len(groupMap) == 1 { // do fast
|
if len(groupMap) == 1 { // do fast
|
||||||
for peer, group := range groupMap { // only one group
|
for peer, group := range groupMap { // only one group
|
||||||
return cluster.Relay(peer, c, makeArgs("DEL", group...))
|
return cluster.Relay(peer, c, makeArgs("DEL", group...))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// prepare
|
// prepare
|
||||||
var errReply redis.Reply
|
var errReply redis.Reply
|
||||||
txId := cluster.idGenerator.NextId()
|
txId := cluster.idGenerator.NextId()
|
||||||
txIdStr := strconv.FormatInt(txId, 10)
|
txIdStr := strconv.FormatInt(txId, 10)
|
||||||
rollback := false
|
rollback := false
|
||||||
for peer, group := range groupMap {
|
for peer, group := range groupMap {
|
||||||
args := []string{txIdStr}
|
args := []string{txIdStr}
|
||||||
args = append(args, group...)
|
args = append(args, group...)
|
||||||
var resp redis.Reply
|
var resp redis.Reply
|
||||||
if peer == cluster.self {
|
if peer == cluster.self {
|
||||||
resp = PrepareDel(cluster, c, makeArgs("PrepareDel", args...))
|
resp = PrepareDel(cluster, c, makeArgs("PrepareDel", args...))
|
||||||
} else {
|
} else {
|
||||||
resp = cluster.Relay(peer, c, makeArgs("PrepareDel", args...))
|
resp = cluster.Relay(peer, c, makeArgs("PrepareDel", args...))
|
||||||
}
|
}
|
||||||
if reply.IsErrorReply(resp) {
|
if reply.IsErrorReply(resp) {
|
||||||
errReply = resp
|
errReply = resp
|
||||||
rollback = true
|
rollback = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var respList []redis.Reply
|
var respList []redis.Reply
|
||||||
if rollback {
|
if rollback {
|
||||||
// rollback
|
// rollback
|
||||||
RequestRollback(cluster, c, txId, groupMap)
|
RequestRollback(cluster, c, txId, groupMap)
|
||||||
} else {
|
} else {
|
||||||
// commit
|
// commit
|
||||||
respList, errReply = RequestCommit(cluster, c, txId, groupMap)
|
respList, errReply = RequestCommit(cluster, c, txId, groupMap)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
rollback = true
|
rollback = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !rollback {
|
if !rollback {
|
||||||
var deleted int64 = 0
|
var deleted int64 = 0
|
||||||
for _, resp := range respList {
|
for _, resp := range respList {
|
||||||
intResp := resp.(*reply.IntReply)
|
intResp := resp.(*reply.IntReply)
|
||||||
deleted += intResp.Code
|
deleted += intResp.Code
|
||||||
}
|
}
|
||||||
return reply.MakeIntReply(int64(deleted))
|
return reply.MakeIntReply(int64(deleted))
|
||||||
}
|
}
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
|
|
||||||
// args: PrepareDel id keys...
|
// args: PrepareDel id keys...
|
||||||
func PrepareDel(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
func PrepareDel(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||||
if len(args) < 3 {
|
if len(args) < 3 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'preparedel' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'preparedel' command")
|
||||||
}
|
}
|
||||||
txId := string(args[1])
|
txId := string(args[1])
|
||||||
keys := make([]string, 0, len(args)-2)
|
keys := make([]string, 0, len(args)-2)
|
||||||
for i := 2; i < len(args); i++ {
|
for i := 2; i < len(args); i++ {
|
||||||
arg := args[i]
|
arg := args[i]
|
||||||
keys = append(keys, string(arg))
|
keys = append(keys, string(arg))
|
||||||
}
|
}
|
||||||
txArgs := makeArgs("DEL", keys...) // actual args for cluster.db
|
txArgs := makeArgs("DEL", keys...) // actual args for cluster.db
|
||||||
tx := NewTransaction(cluster, c, txId, txArgs, keys)
|
tx := NewTransaction(cluster, c, txId, txArgs, keys)
|
||||||
cluster.transactions.Put(txId, tx)
|
cluster.transactions.Put(txId, tx)
|
||||||
err := tx.prepare()
|
err := tx.prepare()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reply.MakeErrReply(err.Error())
|
return reply.MakeErrReply(err.Error())
|
||||||
}
|
}
|
||||||
return &reply.OkReply{}
|
return &reply.OkReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// invoker should provide lock
|
// invoker should provide lock
|
||||||
func CommitDel(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply {
|
func CommitDel(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply {
|
||||||
keys := make([]string, len(tx.args))
|
keys := make([]string, len(tx.args))
|
||||||
for i, v := range tx.args {
|
for i, v := range tx.args {
|
||||||
keys[i] = string(v)
|
keys[i] = string(v)
|
||||||
}
|
}
|
||||||
keys = keys[1:]
|
keys = keys[1:]
|
||||||
|
|
||||||
deleted := cluster.db.Removes(keys...)
|
deleted := cluster.db.Removes(keys...)
|
||||||
if deleted > 0 {
|
if deleted > 0 {
|
||||||
cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args))
|
cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args))
|
||||||
}
|
}
|
||||||
return reply.MakeIntReply(int64(deleted))
|
return reply.MakeIntReply(int64(deleted))
|
||||||
}
|
}
|
||||||
|
@@ -1,85 +1,85 @@
|
|||||||
package idgenerator
|
package idgenerator
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
"log"
|
"log"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
workerIdBits int64 = 5
|
workerIdBits int64 = 5
|
||||||
datacenterIdBits int64 = 5
|
datacenterIdBits int64 = 5
|
||||||
sequenceBits int64 = 12
|
sequenceBits int64 = 12
|
||||||
|
|
||||||
maxWorkerId int64 = -1 ^ (-1 << uint64(workerIdBits))
|
maxWorkerId int64 = -1 ^ (-1 << uint64(workerIdBits))
|
||||||
maxDatacenterId int64 = -1 ^ (-1 << uint64(datacenterIdBits))
|
maxDatacenterId int64 = -1 ^ (-1 << uint64(datacenterIdBits))
|
||||||
maxSequence int64 = -1 ^ (-1 << uint64(sequenceBits))
|
maxSequence int64 = -1 ^ (-1 << uint64(sequenceBits))
|
||||||
|
|
||||||
timeLeft uint8 = 22
|
timeLeft uint8 = 22
|
||||||
dataLeft uint8 = 17
|
dataLeft uint8 = 17
|
||||||
workLeft uint8 = 12
|
workLeft uint8 = 12
|
||||||
|
|
||||||
twepoch int64 = 1525705533000
|
twepoch int64 = 1525705533000
|
||||||
)
|
)
|
||||||
|
|
||||||
type IdGenerator struct {
|
type IdGenerator struct {
|
||||||
mu *sync.Mutex
|
mu *sync.Mutex
|
||||||
lastStamp int64
|
lastStamp int64
|
||||||
workerId int64
|
workerId int64
|
||||||
dataCenterId int64
|
dataCenterId int64
|
||||||
sequence int64
|
sequence int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeGenerator(cluster string, node string) *IdGenerator {
|
func MakeGenerator(cluster string, node string) *IdGenerator {
|
||||||
fnv64 := fnv.New64()
|
fnv64 := fnv.New64()
|
||||||
_, _ = fnv64.Write([]byte(cluster))
|
_, _ = fnv64.Write([]byte(cluster))
|
||||||
dataCenterId := int64(fnv64.Sum64())
|
dataCenterId := int64(fnv64.Sum64())
|
||||||
|
|
||||||
fnv64.Reset()
|
fnv64.Reset()
|
||||||
_, _ = fnv64.Write([]byte(node))
|
_, _ = fnv64.Write([]byte(node))
|
||||||
workerId := int64(fnv64.Sum64())
|
workerId := int64(fnv64.Sum64())
|
||||||
|
|
||||||
return &IdGenerator{
|
return &IdGenerator{
|
||||||
mu: &sync.Mutex{},
|
mu: &sync.Mutex{},
|
||||||
lastStamp: -1,
|
lastStamp: -1,
|
||||||
dataCenterId: dataCenterId,
|
dataCenterId: dataCenterId,
|
||||||
workerId: workerId,
|
workerId: workerId,
|
||||||
sequence: 1,
|
sequence: 1,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *IdGenerator) getCurrentTime() int64 {
|
func (w *IdGenerator) getCurrentTime() int64 {
|
||||||
return time.Now().UnixNano() / 1e6
|
return time.Now().UnixNano() / 1e6
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *IdGenerator) NextId() int64 {
|
func (w *IdGenerator) NextId() int64 {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
timestamp := w.getCurrentTime()
|
timestamp := w.getCurrentTime()
|
||||||
if timestamp < w.lastStamp {
|
if timestamp < w.lastStamp {
|
||||||
log.Fatal("can not generate id")
|
log.Fatal("can not generate id")
|
||||||
}
|
}
|
||||||
if w.lastStamp == timestamp {
|
if w.lastStamp == timestamp {
|
||||||
w.sequence = (w.sequence + 1) & maxSequence
|
w.sequence = (w.sequence + 1) & maxSequence
|
||||||
if w.sequence == 0 {
|
if w.sequence == 0 {
|
||||||
for timestamp <= w.lastStamp {
|
for timestamp <= w.lastStamp {
|
||||||
timestamp = w.getCurrentTime()
|
timestamp = w.getCurrentTime()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
w.sequence = 0
|
w.sequence = 0
|
||||||
}
|
}
|
||||||
w.lastStamp = timestamp
|
w.lastStamp = timestamp
|
||||||
|
|
||||||
return ((timestamp - twepoch) << timeLeft) | (w.dataCenterId << dataLeft) | (w.workerId << workLeft) | w.sequence
|
return ((timestamp - twepoch) << timeLeft) | (w.dataCenterId << dataLeft) | (w.workerId << workLeft) | w.sequence
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *IdGenerator) tilNextMillis() int64 {
|
func (w *IdGenerator) tilNextMillis() int64 {
|
||||||
timestamp := w.getCurrentTime()
|
timestamp := w.getCurrentTime()
|
||||||
if timestamp <= w.lastStamp {
|
if timestamp <= w.lastStamp {
|
||||||
timestamp = w.getCurrentTime()
|
timestamp = w.getCurrentTime()
|
||||||
}
|
}
|
||||||
return timestamp
|
return timestamp
|
||||||
}
|
}
|
||||||
|
@@ -1,159 +1,159 @@
|
|||||||
package cluster
|
package cluster
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/HDT3213/godis/src/db"
|
"github.com/HDT3213/godis/src/db"
|
||||||
"github.com/HDT3213/godis/src/interface/redis"
|
"github.com/HDT3213/godis/src/interface/redis"
|
||||||
"github.com/HDT3213/godis/src/redis/reply"
|
"github.com/HDT3213/godis/src/redis/reply"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
func MGet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
func MGet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||||
if len(args) < 2 {
|
if len(args) < 2 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command")
|
||||||
}
|
}
|
||||||
keys := make([]string, len(args)-1)
|
keys := make([]string, len(args)-1)
|
||||||
for i := 1; i < len(args); i++ {
|
for i := 1; i < len(args); i++ {
|
||||||
keys[i-1] = string(args[i])
|
keys[i-1] = string(args[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
resultMap := make(map[string][]byte)
|
resultMap := make(map[string][]byte)
|
||||||
groupMap := cluster.groupBy(keys)
|
groupMap := cluster.groupBy(keys)
|
||||||
for peer, group := range groupMap {
|
for peer, group := range groupMap {
|
||||||
resp := cluster.Relay(peer, c, makeArgs("MGET", group...))
|
resp := cluster.Relay(peer, c, makeArgs("MGET", group...))
|
||||||
if reply.IsErrorReply(resp) {
|
if reply.IsErrorReply(resp) {
|
||||||
errReply := resp.(reply.ErrorReply)
|
errReply := resp.(reply.ErrorReply)
|
||||||
return reply.MakeErrReply(fmt.Sprintf("ERR during get %s occurs: %v", group[0], errReply.Error()))
|
return reply.MakeErrReply(fmt.Sprintf("ERR during get %s occurs: %v", group[0], errReply.Error()))
|
||||||
}
|
}
|
||||||
arrReply, _ := resp.(*reply.MultiBulkReply)
|
arrReply, _ := resp.(*reply.MultiBulkReply)
|
||||||
for i, v := range arrReply.Args {
|
for i, v := range arrReply.Args {
|
||||||
key := group[i]
|
key := group[i]
|
||||||
resultMap[key] = v
|
resultMap[key] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result := make([][]byte, len(keys))
|
result := make([][]byte, len(keys))
|
||||||
for i, k := range keys {
|
for i, k := range keys {
|
||||||
result[i] = resultMap[k]
|
result[i] = resultMap[k]
|
||||||
}
|
}
|
||||||
return reply.MakeMultiBulkReply(result)
|
return reply.MakeMultiBulkReply(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
// args: PrepareMSet id keys...
|
// args: PrepareMSet id keys...
|
||||||
func PrepareMSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
func PrepareMSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||||
if len(args) < 3 {
|
if len(args) < 3 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'preparemset' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'preparemset' command")
|
||||||
}
|
}
|
||||||
txId := string(args[1])
|
txId := string(args[1])
|
||||||
size := (len(args) - 2) / 2
|
size := (len(args) - 2) / 2
|
||||||
keys := make([]string, size)
|
keys := make([]string, size)
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
keys[i] = string(args[2*i+2])
|
keys[i] = string(args[2*i+2])
|
||||||
}
|
}
|
||||||
|
|
||||||
txArgs := [][]byte{
|
txArgs := [][]byte{
|
||||||
[]byte("MSet"),
|
[]byte("MSet"),
|
||||||
} // actual args for cluster.db
|
} // actual args for cluster.db
|
||||||
txArgs = append(txArgs, args[2:]...)
|
txArgs = append(txArgs, args[2:]...)
|
||||||
tx := NewTransaction(cluster, c, txId, txArgs, keys)
|
tx := NewTransaction(cluster, c, txId, txArgs, keys)
|
||||||
cluster.transactions.Put(txId, tx)
|
cluster.transactions.Put(txId, tx)
|
||||||
err := tx.prepare()
|
err := tx.prepare()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reply.MakeErrReply(err.Error())
|
return reply.MakeErrReply(err.Error())
|
||||||
}
|
}
|
||||||
return &reply.OkReply{}
|
return &reply.OkReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// invoker should provide lock
|
// invoker should provide lock
|
||||||
func CommitMSet(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply {
|
func CommitMSet(cluster *Cluster, c redis.Connection, tx *Transaction) redis.Reply {
|
||||||
size := len(tx.args) / 2
|
size := len(tx.args) / 2
|
||||||
keys := make([]string, size)
|
keys := make([]string, size)
|
||||||
values := make([][]byte, size)
|
values := make([][]byte, size)
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
keys[i] = string(tx.args[2*i+1])
|
keys[i] = string(tx.args[2*i+1])
|
||||||
values[i] = tx.args[2*i+2]
|
values[i] = tx.args[2*i+2]
|
||||||
}
|
}
|
||||||
for i, key := range keys {
|
for i, key := range keys {
|
||||||
value := values[i]
|
value := values[i]
|
||||||
cluster.db.Put(key, &db.DataEntity{Data: value})
|
cluster.db.Put(key, &db.DataEntity{Data: value})
|
||||||
}
|
}
|
||||||
cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args))
|
cluster.db.AddAof(reply.MakeMultiBulkReply(tx.args))
|
||||||
return &reply.OkReply{}
|
return &reply.OkReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func MSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
func MSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||||
argCount := len(args) - 1
|
argCount := len(args) - 1
|
||||||
if argCount%2 != 0 || argCount < 1 {
|
if argCount%2 != 0 || argCount < 1 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
|
||||||
}
|
}
|
||||||
|
|
||||||
size := argCount / 2
|
size := argCount / 2
|
||||||
keys := make([]string, size)
|
keys := make([]string, size)
|
||||||
valueMap := make(map[string]string)
|
valueMap := make(map[string]string)
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
keys[i] = string(args[2*i+1])
|
keys[i] = string(args[2*i+1])
|
||||||
valueMap[keys[i]] = string(args[2*i+2])
|
valueMap[keys[i]] = string(args[2*i+2])
|
||||||
}
|
}
|
||||||
|
|
||||||
groupMap := cluster.groupBy(keys)
|
groupMap := cluster.groupBy(keys)
|
||||||
if len(groupMap) == 1 { // do fast
|
if len(groupMap) == 1 { // do fast
|
||||||
for peer := range groupMap {
|
for peer := range groupMap {
|
||||||
return cluster.Relay(peer, c, args)
|
return cluster.Relay(peer, c, args)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//prepare
|
//prepare
|
||||||
var errReply redis.Reply
|
var errReply redis.Reply
|
||||||
txId := cluster.idGenerator.NextId()
|
txId := cluster.idGenerator.NextId()
|
||||||
txIdStr := strconv.FormatInt(txId, 10)
|
txIdStr := strconv.FormatInt(txId, 10)
|
||||||
rollback := false
|
rollback := false
|
||||||
for peer, group := range groupMap {
|
for peer, group := range groupMap {
|
||||||
peerArgs := []string{txIdStr}
|
peerArgs := []string{txIdStr}
|
||||||
for _, k := range group {
|
for _, k := range group {
|
||||||
peerArgs = append(peerArgs, k, valueMap[k])
|
peerArgs = append(peerArgs, k, valueMap[k])
|
||||||
}
|
}
|
||||||
var resp redis.Reply
|
var resp redis.Reply
|
||||||
if peer == cluster.self {
|
if peer == cluster.self {
|
||||||
resp = PrepareMSet(cluster, c, makeArgs("PrepareMSet", peerArgs...))
|
resp = PrepareMSet(cluster, c, makeArgs("PrepareMSet", peerArgs...))
|
||||||
} else {
|
} else {
|
||||||
resp = cluster.Relay(peer, c, makeArgs("PrepareMSet", peerArgs...))
|
resp = cluster.Relay(peer, c, makeArgs("PrepareMSet", peerArgs...))
|
||||||
}
|
}
|
||||||
if reply.IsErrorReply(resp) {
|
if reply.IsErrorReply(resp) {
|
||||||
errReply = resp
|
errReply = resp
|
||||||
rollback = true
|
rollback = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if rollback {
|
if rollback {
|
||||||
// rollback
|
// rollback
|
||||||
RequestRollback(cluster, c, txId, groupMap)
|
RequestRollback(cluster, c, txId, groupMap)
|
||||||
} else {
|
} else {
|
||||||
_, errReply = RequestCommit(cluster, c, txId, groupMap)
|
_, errReply = RequestCommit(cluster, c, txId, groupMap)
|
||||||
rollback = errReply != nil
|
rollback = errReply != nil
|
||||||
}
|
}
|
||||||
if !rollback {
|
if !rollback {
|
||||||
return &reply.OkReply{}
|
return &reply.OkReply{}
|
||||||
}
|
}
|
||||||
return errReply
|
return errReply
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func MSetNX(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
func MSetNX(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||||
argCount := len(args) - 1
|
argCount := len(args) - 1
|
||||||
if argCount%2 != 0 || argCount < 1 {
|
if argCount%2 != 0 || argCount < 1 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command")
|
||||||
}
|
}
|
||||||
var peer string
|
var peer string
|
||||||
size := argCount / 2
|
size := argCount / 2
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
key := string(args[2*i])
|
key := string(args[2*i])
|
||||||
currentPeer := cluster.peerPicker.Get(key)
|
currentPeer := cluster.peerPicker.Get(key)
|
||||||
if peer == "" {
|
if peer == "" {
|
||||||
peer = currentPeer
|
peer = currentPeer
|
||||||
} else {
|
} else {
|
||||||
if peer != currentPeer {
|
if peer != currentPeer {
|
||||||
return reply.MakeErrReply("ERR msetnx must within one slot in cluster mode")
|
return reply.MakeErrReply("ERR msetnx must within one slot in cluster mode")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return cluster.Relay(peer, c, args)
|
return cluster.Relay(peer, c, args)
|
||||||
}
|
}
|
||||||
|
@@ -1,40 +1,39 @@
|
|||||||
package cluster
|
package cluster
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/HDT3213/godis/src/interface/redis"
|
"github.com/HDT3213/godis/src/interface/redis"
|
||||||
"github.com/HDT3213/godis/src/redis/reply"
|
"github.com/HDT3213/godis/src/redis/reply"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: support multiplex slots
|
// TODO: support multiplex slots
|
||||||
func Rename(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
func Rename(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||||
if len(args) != 3 {
|
if len(args) != 3 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rename' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'rename' command")
|
||||||
}
|
}
|
||||||
src := string(args[1])
|
src := string(args[1])
|
||||||
dest := string(args[2])
|
dest := string(args[2])
|
||||||
|
|
||||||
srcPeer := cluster.peerPicker.Get(src)
|
srcPeer := cluster.peerPicker.Get(src)
|
||||||
destPeer := cluster.peerPicker.Get(dest)
|
destPeer := cluster.peerPicker.Get(dest)
|
||||||
|
|
||||||
if srcPeer != destPeer {
|
if srcPeer != destPeer {
|
||||||
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
|
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
|
||||||
}
|
}
|
||||||
return cluster.Relay(srcPeer, c, args)
|
return cluster.Relay(srcPeer, c, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
func RenameNx(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
func RenameNx(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||||
if len(args) != 3 {
|
if len(args) != 3 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'renamenx' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'renamenx' command")
|
||||||
}
|
}
|
||||||
src := string(args[1])
|
src := string(args[1])
|
||||||
dest := string(args[2])
|
dest := string(args[2])
|
||||||
|
|
||||||
srcPeer := cluster.peerPicker.Get(src)
|
srcPeer := cluster.peerPicker.Get(src)
|
||||||
destPeer := cluster.peerPicker.Get(dest)
|
destPeer := cluster.peerPicker.Get(dest)
|
||||||
|
|
||||||
if srcPeer != destPeer {
|
if srcPeer != destPeer {
|
||||||
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
|
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
|
||||||
}
|
}
|
||||||
return cluster.Relay(srcPeer, c, args)
|
return cluster.Relay(srcPeer, c, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -3,106 +3,106 @@ package cluster
|
|||||||
import "github.com/HDT3213/godis/src/interface/redis"
|
import "github.com/HDT3213/godis/src/interface/redis"
|
||||||
|
|
||||||
func MakeRouter() map[string]CmdFunc {
|
func MakeRouter() map[string]CmdFunc {
|
||||||
routerMap := make(map[string]CmdFunc)
|
routerMap := make(map[string]CmdFunc)
|
||||||
routerMap["ping"] = Ping
|
routerMap["ping"] = Ping
|
||||||
|
|
||||||
routerMap["commit"] = Commit
|
routerMap["commit"] = Commit
|
||||||
routerMap["rollback"] = Rollback
|
routerMap["rollback"] = Rollback
|
||||||
routerMap["del"] = Del
|
routerMap["del"] = Del
|
||||||
routerMap["preparedel"] = PrepareDel
|
routerMap["preparedel"] = PrepareDel
|
||||||
routerMap["preparemset"] = PrepareMSet
|
routerMap["preparemset"] = PrepareMSet
|
||||||
|
|
||||||
routerMap["expire"] = defaultFunc
|
routerMap["expire"] = defaultFunc
|
||||||
routerMap["expireat"] = defaultFunc
|
routerMap["expireat"] = defaultFunc
|
||||||
routerMap["pexpire"] = defaultFunc
|
routerMap["pexpire"] = defaultFunc
|
||||||
routerMap["pexpireat"] = defaultFunc
|
routerMap["pexpireat"] = defaultFunc
|
||||||
routerMap["ttl"] = defaultFunc
|
routerMap["ttl"] = defaultFunc
|
||||||
routerMap["pttl"] = defaultFunc
|
routerMap["pttl"] = defaultFunc
|
||||||
routerMap["persist"] = defaultFunc
|
routerMap["persist"] = defaultFunc
|
||||||
routerMap["exists"] = defaultFunc
|
routerMap["exists"] = defaultFunc
|
||||||
routerMap["type"] = defaultFunc
|
routerMap["type"] = defaultFunc
|
||||||
routerMap["rename"] = Rename
|
routerMap["rename"] = Rename
|
||||||
routerMap["renamenx"] = RenameNx
|
routerMap["renamenx"] = RenameNx
|
||||||
|
|
||||||
routerMap["set"] = defaultFunc
|
routerMap["set"] = defaultFunc
|
||||||
routerMap["setnx"] = defaultFunc
|
routerMap["setnx"] = defaultFunc
|
||||||
routerMap["setex"] = defaultFunc
|
routerMap["setex"] = defaultFunc
|
||||||
routerMap["psetex"] = defaultFunc
|
routerMap["psetex"] = defaultFunc
|
||||||
routerMap["mset"] = MSet
|
routerMap["mset"] = MSet
|
||||||
routerMap["mget"] = MGet
|
routerMap["mget"] = MGet
|
||||||
routerMap["msetnx"] = MSetNX
|
routerMap["msetnx"] = MSetNX
|
||||||
routerMap["get"] = defaultFunc
|
routerMap["get"] = defaultFunc
|
||||||
routerMap["getset"] = defaultFunc
|
routerMap["getset"] = defaultFunc
|
||||||
routerMap["incr"] = defaultFunc
|
routerMap["incr"] = defaultFunc
|
||||||
routerMap["incrby"] = defaultFunc
|
routerMap["incrby"] = defaultFunc
|
||||||
routerMap["incrbyfloat"] = defaultFunc
|
routerMap["incrbyfloat"] = defaultFunc
|
||||||
routerMap["decr"] = defaultFunc
|
routerMap["decr"] = defaultFunc
|
||||||
routerMap["decrby"] = defaultFunc
|
routerMap["decrby"] = defaultFunc
|
||||||
|
|
||||||
routerMap["lpush"] = defaultFunc
|
routerMap["lpush"] = defaultFunc
|
||||||
routerMap["lpushx"] = defaultFunc
|
routerMap["lpushx"] = defaultFunc
|
||||||
routerMap["rpush"] = defaultFunc
|
routerMap["rpush"] = defaultFunc
|
||||||
routerMap["rpushx"] = defaultFunc
|
routerMap["rpushx"] = defaultFunc
|
||||||
routerMap["lpop"] = defaultFunc
|
routerMap["lpop"] = defaultFunc
|
||||||
routerMap["rpop"] = defaultFunc
|
routerMap["rpop"] = defaultFunc
|
||||||
//routerMap["rpoplpush"] = RPopLPush
|
//routerMap["rpoplpush"] = RPopLPush
|
||||||
routerMap["lrem"] = defaultFunc
|
routerMap["lrem"] = defaultFunc
|
||||||
routerMap["llen"] = defaultFunc
|
routerMap["llen"] = defaultFunc
|
||||||
routerMap["lindex"] = defaultFunc
|
routerMap["lindex"] = defaultFunc
|
||||||
routerMap["lset"] = defaultFunc
|
routerMap["lset"] = defaultFunc
|
||||||
routerMap["lrange"] = defaultFunc
|
routerMap["lrange"] = defaultFunc
|
||||||
|
|
||||||
routerMap["hset"] = defaultFunc
|
routerMap["hset"] = defaultFunc
|
||||||
routerMap["hsetnx"] = defaultFunc
|
routerMap["hsetnx"] = defaultFunc
|
||||||
routerMap["hget"] = defaultFunc
|
routerMap["hget"] = defaultFunc
|
||||||
routerMap["hexists"] = defaultFunc
|
routerMap["hexists"] = defaultFunc
|
||||||
routerMap["hdel"] = defaultFunc
|
routerMap["hdel"] = defaultFunc
|
||||||
routerMap["hlen"] = defaultFunc
|
routerMap["hlen"] = defaultFunc
|
||||||
routerMap["hmget"] = defaultFunc
|
routerMap["hmget"] = defaultFunc
|
||||||
routerMap["hmset"] = defaultFunc
|
routerMap["hmset"] = defaultFunc
|
||||||
routerMap["hkeys"] = defaultFunc
|
routerMap["hkeys"] = defaultFunc
|
||||||
routerMap["hvals"] = defaultFunc
|
routerMap["hvals"] = defaultFunc
|
||||||
routerMap["hgetall"] = defaultFunc
|
routerMap["hgetall"] = defaultFunc
|
||||||
routerMap["hincrby"] = defaultFunc
|
routerMap["hincrby"] = defaultFunc
|
||||||
routerMap["hincrbyfloat"] = defaultFunc
|
routerMap["hincrbyfloat"] = defaultFunc
|
||||||
|
|
||||||
routerMap["sadd"] = defaultFunc
|
routerMap["sadd"] = defaultFunc
|
||||||
routerMap["sismember"] = defaultFunc
|
routerMap["sismember"] = defaultFunc
|
||||||
routerMap["srem"] = defaultFunc
|
routerMap["srem"] = defaultFunc
|
||||||
routerMap["scard"] = defaultFunc
|
routerMap["scard"] = defaultFunc
|
||||||
routerMap["smembers"] = defaultFunc
|
routerMap["smembers"] = defaultFunc
|
||||||
routerMap["sinter"] = defaultFunc
|
routerMap["sinter"] = defaultFunc
|
||||||
routerMap["sinterstore"] = defaultFunc
|
routerMap["sinterstore"] = defaultFunc
|
||||||
routerMap["sunion"] = defaultFunc
|
routerMap["sunion"] = defaultFunc
|
||||||
routerMap["sunionstore"] = defaultFunc
|
routerMap["sunionstore"] = defaultFunc
|
||||||
routerMap["sdiff"] = defaultFunc
|
routerMap["sdiff"] = defaultFunc
|
||||||
routerMap["sdiffstore"] = defaultFunc
|
routerMap["sdiffstore"] = defaultFunc
|
||||||
routerMap["srandmember"] = defaultFunc
|
routerMap["srandmember"] = defaultFunc
|
||||||
|
|
||||||
routerMap["zadd"] = defaultFunc
|
routerMap["zadd"] = defaultFunc
|
||||||
routerMap["zscore"] = defaultFunc
|
routerMap["zscore"] = defaultFunc
|
||||||
routerMap["zincrby"] = defaultFunc
|
routerMap["zincrby"] = defaultFunc
|
||||||
routerMap["zrank"] = defaultFunc
|
routerMap["zrank"] = defaultFunc
|
||||||
routerMap["zcount"] = defaultFunc
|
routerMap["zcount"] = defaultFunc
|
||||||
routerMap["zrevrank"] = defaultFunc
|
routerMap["zrevrank"] = defaultFunc
|
||||||
routerMap["zcard"] = defaultFunc
|
routerMap["zcard"] = defaultFunc
|
||||||
routerMap["zrange"] = defaultFunc
|
routerMap["zrange"] = defaultFunc
|
||||||
routerMap["zrevrange"] = defaultFunc
|
routerMap["zrevrange"] = defaultFunc
|
||||||
routerMap["zrangebyscore"] = defaultFunc
|
routerMap["zrangebyscore"] = defaultFunc
|
||||||
routerMap["zrevrangebyscore"] = defaultFunc
|
routerMap["zrevrangebyscore"] = defaultFunc
|
||||||
routerMap["zrem"] = defaultFunc
|
routerMap["zrem"] = defaultFunc
|
||||||
routerMap["zremrangebyscore"] = defaultFunc
|
routerMap["zremrangebyscore"] = defaultFunc
|
||||||
routerMap["zremrangebyrank"] = defaultFunc
|
routerMap["zremrangebyrank"] = defaultFunc
|
||||||
|
|
||||||
//routerMap["flushdb"] = FlushDB
|
//routerMap["flushdb"] = FlushDB
|
||||||
//routerMap["flushall"] = FlushAll
|
//routerMap["flushall"] = FlushAll
|
||||||
//routerMap["keys"] = Keys
|
//routerMap["keys"] = Keys
|
||||||
|
|
||||||
return routerMap
|
return routerMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultFunc(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
func defaultFunc(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||||
key := string(args[1])
|
key := string(args[1])
|
||||||
peer := cluster.peerPicker.Get(key)
|
peer := cluster.peerPicker.Get(key)
|
||||||
return cluster.Relay(peer, c, args)
|
return cluster.Relay(peer, c, args)
|
||||||
}
|
}
|
||||||
|
@@ -1,188 +1,188 @@
|
|||||||
package cluster
|
package cluster
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/HDT3213/godis/src/db"
|
"github.com/HDT3213/godis/src/db"
|
||||||
"github.com/HDT3213/godis/src/interface/redis"
|
"github.com/HDT3213/godis/src/interface/redis"
|
||||||
"github.com/HDT3213/godis/src/lib/marshal/gob"
|
"github.com/HDT3213/godis/src/lib/marshal/gob"
|
||||||
"github.com/HDT3213/godis/src/redis/reply"
|
"github.com/HDT3213/godis/src/redis/reply"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Transaction struct {
|
type Transaction struct {
|
||||||
id string // transaction id
|
id string // transaction id
|
||||||
args [][]byte // cmd args
|
args [][]byte // cmd args
|
||||||
cluster *Cluster
|
cluster *Cluster
|
||||||
conn redis.Connection
|
conn redis.Connection
|
||||||
|
|
||||||
keys []string // related keys
|
keys []string // related keys
|
||||||
undoLog map[string][]byte // store data for undoLog
|
undoLog map[string][]byte // store data for undoLog
|
||||||
|
|
||||||
lockUntil time.Time
|
lockUntil time.Time
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
status int8
|
status int8
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
maxLockTime = 3 * time.Second
|
maxLockTime = 3 * time.Second
|
||||||
|
|
||||||
CreatedStatus = 0
|
CreatedStatus = 0
|
||||||
PreparedStatus = 1
|
PreparedStatus = 1
|
||||||
CommitedStatus = 2
|
CommitedStatus = 2
|
||||||
RollbackedStatus = 3
|
RollbackedStatus = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewTransaction(cluster *Cluster, c redis.Connection, id string, args [][]byte, keys []string) *Transaction {
|
func NewTransaction(cluster *Cluster, c redis.Connection, id string, args [][]byte, keys []string) *Transaction {
|
||||||
return &Transaction{
|
return &Transaction{
|
||||||
id: id,
|
id: id,
|
||||||
args: args,
|
args: args,
|
||||||
cluster: cluster,
|
cluster: cluster,
|
||||||
conn: c,
|
conn: c,
|
||||||
keys: keys,
|
keys: keys,
|
||||||
status: CreatedStatus,
|
status: CreatedStatus,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// t should contains Keys field
|
// t should contains Keys field
|
||||||
func (tx *Transaction) prepare() error {
|
func (tx *Transaction) prepare() error {
|
||||||
// lock keys
|
// lock keys
|
||||||
tx.cluster.db.Locks(tx.keys...)
|
tx.cluster.db.Locks(tx.keys...)
|
||||||
|
|
||||||
// use context to manage
|
// use context to manage
|
||||||
//tx.lockUntil = time.Now().Add(maxLockTime)
|
//tx.lockUntil = time.Now().Add(maxLockTime)
|
||||||
//ctx, cancel := context.WithDeadline(context.Background(), tx.lockUntil)
|
//ctx, cancel := context.WithDeadline(context.Background(), tx.lockUntil)
|
||||||
//tx.ctx = ctx
|
//tx.ctx = ctx
|
||||||
//tx.cancel = cancel
|
//tx.cancel = cancel
|
||||||
|
|
||||||
// build undoLog
|
// build undoLog
|
||||||
tx.undoLog = make(map[string][]byte)
|
tx.undoLog = make(map[string][]byte)
|
||||||
for _, key := range tx.keys {
|
for _, key := range tx.keys {
|
||||||
entity, ok := tx.cluster.db.Get(key)
|
entity, ok := tx.cluster.db.Get(key)
|
||||||
if ok {
|
if ok {
|
||||||
blob, err := gob.Marshal(entity)
|
blob, err := gob.Marshal(entity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
tx.undoLog[key] = blob
|
tx.undoLog[key] = blob
|
||||||
} else {
|
} else {
|
||||||
tx.undoLog[key] = []byte{} // entity was nil, should be removed while rollback
|
tx.undoLog[key] = []byte{} // entity was nil, should be removed while rollback
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tx.status = PreparedStatus
|
tx.status = PreparedStatus
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *Transaction) rollback() error {
|
func (tx *Transaction) rollback() error {
|
||||||
for key, blob := range tx.undoLog {
|
for key, blob := range tx.undoLog {
|
||||||
if len(blob) > 0 {
|
if len(blob) > 0 {
|
||||||
entity := &db.DataEntity{}
|
entity := &db.DataEntity{}
|
||||||
err := gob.UnMarshal(blob, entity)
|
err := gob.UnMarshal(blob, entity)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
tx.cluster.db.Put(key, entity)
|
tx.cluster.db.Put(key, entity)
|
||||||
} else {
|
} else {
|
||||||
tx.cluster.db.Remove(key)
|
tx.cluster.db.Remove(key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if tx.status != CommitedStatus {
|
if tx.status != CommitedStatus {
|
||||||
tx.cluster.db.UnLocks(tx.keys...)
|
tx.cluster.db.UnLocks(tx.keys...)
|
||||||
}
|
}
|
||||||
tx.status = RollbackedStatus
|
tx.status = RollbackedStatus
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// rollback local transaction
|
// rollback local transaction
|
||||||
func Rollback(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
func Rollback(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||||
if len(args) != 2 {
|
if len(args) != 2 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rollback' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'rollback' command")
|
||||||
}
|
}
|
||||||
txId := string(args[1])
|
txId := string(args[1])
|
||||||
raw, ok := cluster.transactions.Get(txId)
|
raw, ok := cluster.transactions.Get(txId)
|
||||||
if !ok {
|
if !ok {
|
||||||
return reply.MakeIntReply(0)
|
return reply.MakeIntReply(0)
|
||||||
}
|
}
|
||||||
tx, _ := raw.(*Transaction)
|
tx, _ := raw.(*Transaction)
|
||||||
err := tx.rollback()
|
err := tx.rollback()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reply.MakeErrReply(err.Error())
|
return reply.MakeErrReply(err.Error())
|
||||||
}
|
}
|
||||||
return reply.MakeIntReply(1)
|
return reply.MakeIntReply(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// commit local transaction as a worker
|
// commit local transaction as a worker
|
||||||
func Commit(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
func Commit(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||||
if len(args) != 2 {
|
if len(args) != 2 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'commit' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'commit' command")
|
||||||
}
|
}
|
||||||
txId := string(args[1])
|
txId := string(args[1])
|
||||||
raw, ok := cluster.transactions.Get(txId)
|
raw, ok := cluster.transactions.Get(txId)
|
||||||
if !ok {
|
if !ok {
|
||||||
return reply.MakeIntReply(0)
|
return reply.MakeIntReply(0)
|
||||||
}
|
}
|
||||||
tx, _ := raw.(*Transaction)
|
tx, _ := raw.(*Transaction)
|
||||||
|
|
||||||
// finish transaction
|
// finish transaction
|
||||||
defer func() {
|
defer func() {
|
||||||
cluster.db.UnLocks(tx.keys...)
|
cluster.db.UnLocks(tx.keys...)
|
||||||
tx.status = CommitedStatus
|
tx.status = CommitedStatus
|
||||||
//cluster.transactions.Remove(tx.id) // cannot remove, may rollback after commit
|
//cluster.transactions.Remove(tx.id) // cannot remove, may rollback after commit
|
||||||
}()
|
}()
|
||||||
|
|
||||||
cmd := strings.ToLower(string(tx.args[0]))
|
cmd := strings.ToLower(string(tx.args[0]))
|
||||||
var result redis.Reply
|
var result redis.Reply
|
||||||
if cmd == "del" {
|
if cmd == "del" {
|
||||||
result = CommitDel(cluster, c, tx)
|
result = CommitDel(cluster, c, tx)
|
||||||
} else if cmd == "mset" {
|
} else if cmd == "mset" {
|
||||||
result = CommitMSet(cluster, c, tx)
|
result = CommitMSet(cluster, c, tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
if reply.IsErrorReply(result) {
|
if reply.IsErrorReply(result) {
|
||||||
// failed
|
// failed
|
||||||
err2 := tx.rollback()
|
err2 := tx.rollback()
|
||||||
return reply.MakeErrReply(fmt.Sprintf("err occurs when rollback: %v, origin err: %s", err2, result))
|
return reply.MakeErrReply(fmt.Sprintf("err occurs when rollback: %v, origin err: %s", err2, result))
|
||||||
}
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// request all node commit transaction as leader
|
// request all node commit transaction as leader
|
||||||
func RequestCommit(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) ([]redis.Reply, reply.ErrorReply) {
|
func RequestCommit(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) ([]redis.Reply, reply.ErrorReply) {
|
||||||
var errReply reply.ErrorReply
|
var errReply reply.ErrorReply
|
||||||
txIdStr := strconv.FormatInt(txId, 10)
|
txIdStr := strconv.FormatInt(txId, 10)
|
||||||
respList := make([]redis.Reply, 0, len(peers))
|
respList := make([]redis.Reply, 0, len(peers))
|
||||||
for peer := range peers {
|
for peer := range peers {
|
||||||
var resp redis.Reply
|
var resp redis.Reply
|
||||||
if peer == cluster.self {
|
if peer == cluster.self {
|
||||||
resp = Commit(cluster, c, makeArgs("commit", txIdStr))
|
resp = Commit(cluster, c, makeArgs("commit", txIdStr))
|
||||||
} else {
|
} else {
|
||||||
resp = cluster.Relay(peer, c, makeArgs("commit", txIdStr))
|
resp = cluster.Relay(peer, c, makeArgs("commit", txIdStr))
|
||||||
}
|
}
|
||||||
if reply.IsErrorReply(resp) {
|
if reply.IsErrorReply(resp) {
|
||||||
errReply = resp.(reply.ErrorReply)
|
errReply = resp.(reply.ErrorReply)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
respList = append(respList, resp)
|
respList = append(respList, resp)
|
||||||
}
|
}
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
RequestRollback(cluster, c, txId, peers)
|
RequestRollback(cluster, c, txId, peers)
|
||||||
return nil, errReply
|
return nil, errReply
|
||||||
}
|
}
|
||||||
return respList, nil
|
return respList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// request all node rollback transaction as leader
|
// request all node rollback transaction as leader
|
||||||
func RequestRollback(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) {
|
func RequestRollback(cluster *Cluster, c redis.Connection, txId int64, peers map[string][]string) {
|
||||||
txIdStr := strconv.FormatInt(txId, 10)
|
txIdStr := strconv.FormatInt(txId, 10)
|
||||||
for peer := range peers {
|
for peer := range peers {
|
||||||
if peer == cluster.self {
|
if peer == cluster.self {
|
||||||
Rollback(cluster, c, makeArgs("rollback", txIdStr))
|
Rollback(cluster, c, makeArgs("rollback", txIdStr))
|
||||||
} else {
|
} else {
|
||||||
cluster.Relay(peer, c, makeArgs("rollback", txIdStr))
|
cluster.Relay(peer, c, makeArgs("rollback", txIdStr))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -4,7 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/HDT3213/godis/src/config"
|
"github.com/HDT3213/godis/src/config"
|
||||||
"github.com/HDT3213/godis/src/lib/logger"
|
"github.com/HDT3213/godis/src/lib/logger"
|
||||||
RedisServer "github.com/HDT3213/godis/src/redis/server"
|
RedisServer "github.com/HDT3213/godis/src/redis/server"
|
||||||
"github.com/HDT3213/godis/src/tcp"
|
"github.com/HDT3213/godis/src/tcp"
|
||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
@@ -23,6 +23,6 @@ func main() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
tcp.ListenAndServe(&tcp.Config{
|
tcp.ListenAndServe(&tcp.Config{
|
||||||
Address: fmt.Sprintf("%s:%d", config.Properties.Bind, config.Properties.Port),
|
Address: fmt.Sprintf("%s:%d", config.Properties.Bind, config.Properties.Port),
|
||||||
}, RedisServer.MakeHandler())
|
}, RedisServer.MakeHandler())
|
||||||
}
|
}
|
||||||
|
@@ -23,12 +23,12 @@ type PropertyHolder struct {
|
|||||||
var Properties *PropertyHolder
|
var Properties *PropertyHolder
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
// default config
|
// default config
|
||||||
Properties = &PropertyHolder{
|
Properties = &PropertyHolder{
|
||||||
Bind: "127.0.0.1",
|
Bind: "127.0.0.1",
|
||||||
Port: 6379,
|
Port: 6379,
|
||||||
AppendOnly: false,
|
AppendOnly: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadConfig(configFilename string) *PropertyHolder {
|
func LoadConfig(configFilename string) *PropertyHolder {
|
||||||
|
@@ -1,16 +1,16 @@
|
|||||||
package dict
|
package dict
|
||||||
|
|
||||||
type Consumer func(key string, val interface{})bool
|
type Consumer func(key string, val interface{}) bool
|
||||||
|
|
||||||
type Dict interface {
|
type Dict interface {
|
||||||
Get(key string) (val interface{}, exists bool)
|
Get(key string) (val interface{}, exists bool)
|
||||||
Len() int
|
Len() int
|
||||||
Put(key string, val interface{}) (result int)
|
Put(key string, val interface{}) (result int)
|
||||||
PutIfAbsent(key string, val interface{}) (result int)
|
PutIfAbsent(key string, val interface{}) (result int)
|
||||||
PutIfExists(key string, val interface{}) (result int)
|
PutIfExists(key string, val interface{}) (result int)
|
||||||
Remove(key string) (result int)
|
Remove(key string) (result int)
|
||||||
ForEach(consumer Consumer)
|
ForEach(consumer Consumer)
|
||||||
Keys() []string
|
Keys() []string
|
||||||
RandomKeys(limit int) []string
|
RandomKeys(limit int) []string
|
||||||
RandomDistinctKeys(limit int) []string
|
RandomDistinctKeys(limit int) []string
|
||||||
}
|
}
|
||||||
|
@@ -1,244 +1,244 @@
|
|||||||
package dict
|
package dict
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPut(t *testing.T) {
|
func TestPut(t *testing.T) {
|
||||||
d := MakeConcurrent(0)
|
d := MakeConcurrent(0)
|
||||||
count := 100
|
count := 100
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(count)
|
wg.Add(count)
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
go func(i int) {
|
go func(i int) {
|
||||||
// insert
|
// insert
|
||||||
key := "k" + strconv.Itoa(i)
|
key := "k" + strconv.Itoa(i)
|
||||||
ret := d.Put(key, i)
|
ret := d.Put(key, i)
|
||||||
if ret != 1 { // insert 1
|
if ret != 1 { // insert 1
|
||||||
t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key)
|
t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key)
|
||||||
}
|
}
|
||||||
val, ok := d.Get(key)
|
val, ok := d.Get(key)
|
||||||
if ok {
|
if ok {
|
||||||
intVal, _ := val.(int)
|
intVal, _ := val.(int)
|
||||||
if intVal != i {
|
if intVal != i {
|
||||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key)
|
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
_, ok := d.Get(key)
|
_, ok := d.Get(key)
|
||||||
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
|
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
|
||||||
}
|
}
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPutIfAbsent(t *testing.T) {
|
func TestPutIfAbsent(t *testing.T) {
|
||||||
d := MakeConcurrent(0)
|
d := MakeConcurrent(0)
|
||||||
count := 100
|
count := 100
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(count)
|
wg.Add(count)
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
go func(i int) {
|
go func(i int) {
|
||||||
// insert
|
// insert
|
||||||
key := "k" + strconv.Itoa(i)
|
key := "k" + strconv.Itoa(i)
|
||||||
ret := d.PutIfAbsent(key, i)
|
ret := d.PutIfAbsent(key, i)
|
||||||
if ret != 1 { // insert 1
|
if ret != 1 { // insert 1
|
||||||
t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key)
|
t.Error("put test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key: " + key)
|
||||||
}
|
}
|
||||||
val, ok := d.Get(key)
|
val, ok := d.Get(key)
|
||||||
if ok {
|
if ok {
|
||||||
intVal, _ := val.(int)
|
intVal, _ := val.(int)
|
||||||
if intVal != i {
|
if intVal != i {
|
||||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) +
|
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) +
|
||||||
", key: " + key)
|
", key: " + key)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
_, ok := d.Get(key)
|
_, ok := d.Get(key)
|
||||||
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
|
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
|
||||||
}
|
}
|
||||||
|
|
||||||
// update
|
// update
|
||||||
ret = d.PutIfAbsent(key, i * 10)
|
ret = d.PutIfAbsent(key, i*10)
|
||||||
if ret != 0 { // no update
|
if ret != 0 { // no update
|
||||||
t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret))
|
t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret))
|
||||||
}
|
}
|
||||||
val, ok = d.Get(key)
|
val, ok = d.Get(key)
|
||||||
if ok {
|
if ok {
|
||||||
intVal, _ := val.(int)
|
intVal, _ := val.(int)
|
||||||
if intVal != i {
|
if intVal != i {
|
||||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key)
|
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal) + ", key: " + key)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.Error("put test failed: expected true, actual: false, key: " + key)
|
t.Error("put test failed: expected true, actual: false, key: " + key)
|
||||||
}
|
}
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPutIfExists(t *testing.T) {
|
func TestPutIfExists(t *testing.T) {
|
||||||
d := MakeConcurrent(0)
|
d := MakeConcurrent(0)
|
||||||
count := 100
|
count := 100
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(count)
|
wg.Add(count)
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
go func(i int) {
|
go func(i int) {
|
||||||
// insert
|
// insert
|
||||||
key := "k" + strconv.Itoa(i)
|
key := "k" + strconv.Itoa(i)
|
||||||
// insert
|
// insert
|
||||||
ret := d.PutIfExists(key, i)
|
ret := d.PutIfExists(key, i)
|
||||||
if ret != 0 { // insert
|
if ret != 0 { // insert
|
||||||
t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret))
|
t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret))
|
||||||
}
|
}
|
||||||
|
|
||||||
d.Put(key, i)
|
d.Put(key, i)
|
||||||
ret = d.PutIfExists(key, 10 * i)
|
ret = d.PutIfExists(key, 10*i)
|
||||||
val, ok := d.Get(key)
|
val, ok := d.Get(key)
|
||||||
if ok {
|
if ok {
|
||||||
intVal, _ := val.(int)
|
intVal, _ := val.(int)
|
||||||
if intVal != 10 * i {
|
if intVal != 10*i {
|
||||||
t.Error("put test failed: expected " + strconv.Itoa(10 * i) + ", actual: " + strconv.Itoa(intVal))
|
t.Error("put test failed: expected " + strconv.Itoa(10*i) + ", actual: " + strconv.Itoa(intVal))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
_, ok := d.Get(key)
|
_, ok := d.Get(key)
|
||||||
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
|
t.Error("put test failed: expected true, actual: false, key: " + key + ", retry: " + strconv.FormatBool(ok))
|
||||||
}
|
}
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRemove(t *testing.T) {
|
func TestRemove(t *testing.T) {
|
||||||
d := MakeConcurrent(0)
|
d := MakeConcurrent(0)
|
||||||
|
|
||||||
// remove head node
|
// remove head node
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
// insert
|
// insert
|
||||||
key := "k" + strconv.Itoa(i)
|
key := "k" + strconv.Itoa(i)
|
||||||
d.Put(key, i)
|
d.Put(key, i)
|
||||||
}
|
}
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
key := "k" + strconv.Itoa(i)
|
key := "k" + strconv.Itoa(i)
|
||||||
|
|
||||||
val, ok := d.Get(key)
|
val, ok := d.Get(key)
|
||||||
if ok {
|
if ok {
|
||||||
intVal, _ := val.(int)
|
intVal, _ := val.(int)
|
||||||
if intVal != i {
|
if intVal != i {
|
||||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.Error("put test failed: expected true, actual: false")
|
t.Error("put test failed: expected true, actual: false")
|
||||||
}
|
}
|
||||||
|
|
||||||
ret := d.Remove(key)
|
ret := d.Remove(key)
|
||||||
if ret != 1 {
|
if ret != 1 {
|
||||||
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key:" + key)
|
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret) + ", key:" + key)
|
||||||
}
|
}
|
||||||
_, ok = d.Get(key)
|
_, ok = d.Get(key)
|
||||||
if ok {
|
if ok {
|
||||||
t.Error("remove test failed: expected true, actual false")
|
t.Error("remove test failed: expected true, actual false")
|
||||||
}
|
}
|
||||||
ret = d.Remove(key)
|
ret = d.Remove(key)
|
||||||
if ret != 0 {
|
if ret != 0 {
|
||||||
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
|
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove tail node
|
// remove tail node
|
||||||
d = MakeConcurrent(0)
|
d = MakeConcurrent(0)
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
// insert
|
// insert
|
||||||
key := "k" + strconv.Itoa(i)
|
key := "k" + strconv.Itoa(i)
|
||||||
d.Put(key, i)
|
d.Put(key, i)
|
||||||
}
|
}
|
||||||
for i := 9; i >= 0; i-- {
|
for i := 9; i >= 0; i-- {
|
||||||
key := "k" + strconv.Itoa(i)
|
key := "k" + strconv.Itoa(i)
|
||||||
|
|
||||||
val, ok := d.Get(key)
|
val, ok := d.Get(key)
|
||||||
if ok {
|
if ok {
|
||||||
intVal, _ := val.(int)
|
intVal, _ := val.(int)
|
||||||
if intVal != i {
|
if intVal != i {
|
||||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.Error("put test failed: expected true, actual: false")
|
t.Error("put test failed: expected true, actual: false")
|
||||||
}
|
}
|
||||||
|
|
||||||
ret := d.Remove(key)
|
ret := d.Remove(key)
|
||||||
if ret != 1 {
|
if ret != 1 {
|
||||||
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret))
|
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret))
|
||||||
}
|
}
|
||||||
_, ok = d.Get(key)
|
_, ok = d.Get(key)
|
||||||
if ok {
|
if ok {
|
||||||
t.Error("remove test failed: expected true, actual false")
|
t.Error("remove test failed: expected true, actual false")
|
||||||
}
|
}
|
||||||
ret = d.Remove(key)
|
ret = d.Remove(key)
|
||||||
if ret != 0 {
|
if ret != 0 {
|
||||||
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
|
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove middle node
|
// remove middle node
|
||||||
d = MakeConcurrent(0)
|
d = MakeConcurrent(0)
|
||||||
d.Put("head", 0)
|
d.Put("head", 0)
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
// insert
|
// insert
|
||||||
key := "k" + strconv.Itoa(i)
|
key := "k" + strconv.Itoa(i)
|
||||||
d.Put(key, i)
|
d.Put(key, i)
|
||||||
}
|
}
|
||||||
d.Put("tail", 0)
|
d.Put("tail", 0)
|
||||||
for i := 9; i >= 0; i-- {
|
for i := 9; i >= 0; i-- {
|
||||||
key := "k" + strconv.Itoa(i)
|
key := "k" + strconv.Itoa(i)
|
||||||
|
|
||||||
val, ok := d.Get(key)
|
val, ok := d.Get(key)
|
||||||
if ok {
|
if ok {
|
||||||
intVal, _ := val.(int)
|
intVal, _ := val.(int)
|
||||||
if intVal != i {
|
if intVal != i {
|
||||||
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
t.Error("put test failed: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
t.Error("put test failed: expected true, actual: false")
|
t.Error("put test failed: expected true, actual: false")
|
||||||
}
|
}
|
||||||
|
|
||||||
ret := d.Remove(key)
|
ret := d.Remove(key)
|
||||||
if ret != 1 {
|
if ret != 1 {
|
||||||
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret))
|
t.Error("remove test failed: expected result 1, actual: " + strconv.Itoa(ret))
|
||||||
}
|
}
|
||||||
_, ok = d.Get(key)
|
_, ok = d.Get(key)
|
||||||
if ok {
|
if ok {
|
||||||
t.Error("remove test failed: expected true, actual false")
|
t.Error("remove test failed: expected true, actual false")
|
||||||
}
|
}
|
||||||
ret = d.Remove(key)
|
ret = d.Remove(key)
|
||||||
if ret != 0 {
|
if ret != 0 {
|
||||||
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
|
t.Error("remove test failed: expected result 0 actual: " + strconv.Itoa(ret))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestForEach(t *testing.T) {
|
func TestForEach(t *testing.T) {
|
||||||
d := MakeConcurrent(0)
|
d := MakeConcurrent(0)
|
||||||
size := 100
|
size := 100
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
// insert
|
// insert
|
||||||
key := "k" + strconv.Itoa(i)
|
key := "k" + strconv.Itoa(i)
|
||||||
d.Put(key, i)
|
d.Put(key, i)
|
||||||
}
|
}
|
||||||
i := 0
|
i := 0
|
||||||
d.ForEach(func(key string, value interface{})bool {
|
d.ForEach(func(key string, value interface{}) bool {
|
||||||
intVal, _ := value.(int)
|
intVal, _ := value.(int)
|
||||||
expectedKey := "k" + strconv.Itoa(intVal)
|
expectedKey := "k" + strconv.Itoa(intVal)
|
||||||
if key != expectedKey {
|
if key != expectedKey {
|
||||||
t.Error("remove test failed: expected " + expectedKey + ", actual: " + key)
|
t.Error("remove test failed: expected " + expectedKey + ", actual: " + key)
|
||||||
}
|
}
|
||||||
i++
|
i++
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
if i != size {
|
if i != size {
|
||||||
t.Error("remove test failed: expected " + strconv.Itoa(size) + ", actual: " + strconv.Itoa(i))
|
t.Error("remove test failed: expected " + strconv.Itoa(size) + ", actual: " + strconv.Itoa(i))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -1,108 +1,108 @@
|
|||||||
package dict
|
package dict
|
||||||
|
|
||||||
type SimpleDict struct {
|
type SimpleDict struct {
|
||||||
m map[string]interface{}
|
m map[string]interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeSimple() *SimpleDict {
|
func MakeSimple() *SimpleDict {
|
||||||
return &SimpleDict{
|
return &SimpleDict{
|
||||||
m: make(map[string]interface{}),
|
m: make(map[string]interface{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dict *SimpleDict) Get(key string) (val interface{}, exists bool) {
|
func (dict *SimpleDict) Get(key string) (val interface{}, exists bool) {
|
||||||
val, ok := dict.m[key]
|
val, ok := dict.m[key]
|
||||||
return val, ok
|
return val, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dict *SimpleDict) Len() int {
|
func (dict *SimpleDict) Len() int {
|
||||||
if dict.m == nil {
|
if dict.m == nil {
|
||||||
panic("m is nil")
|
panic("m is nil")
|
||||||
}
|
}
|
||||||
return len(dict.m)
|
return len(dict.m)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dict *SimpleDict) Put(key string, val interface{}) (result int) {
|
func (dict *SimpleDict) Put(key string, val interface{}) (result int) {
|
||||||
_, existed := dict.m[key]
|
_, existed := dict.m[key]
|
||||||
dict.m[key] = val
|
dict.m[key] = val
|
||||||
if existed {
|
if existed {
|
||||||
return 0
|
return 0
|
||||||
} else {
|
} else {
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dict *SimpleDict) PutIfAbsent(key string, val interface{}) (result int) {
|
func (dict *SimpleDict) PutIfAbsent(key string, val interface{}) (result int) {
|
||||||
_, existed := dict.m[key]
|
_, existed := dict.m[key]
|
||||||
if existed {
|
if existed {
|
||||||
return 0
|
return 0
|
||||||
} else {
|
} else {
|
||||||
dict.m[key] = val
|
dict.m[key] = val
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dict *SimpleDict) PutIfExists(key string, val interface{}) (result int) {
|
func (dict *SimpleDict) PutIfExists(key string, val interface{}) (result int) {
|
||||||
_, existed := dict.m[key]
|
_, existed := dict.m[key]
|
||||||
if existed {
|
if existed {
|
||||||
dict.m[key] = val
|
dict.m[key] = val
|
||||||
return 1
|
return 1
|
||||||
} else {
|
} else {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dict *SimpleDict) Remove(key string) (result int) {
|
func (dict *SimpleDict) Remove(key string) (result int) {
|
||||||
_, existed := dict.m[key]
|
_, existed := dict.m[key]
|
||||||
delete(dict.m, key)
|
delete(dict.m, key)
|
||||||
if existed {
|
if existed {
|
||||||
return 1
|
return 1
|
||||||
} else {
|
} else {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dict *SimpleDict) Keys() []string {
|
func (dict *SimpleDict) Keys() []string {
|
||||||
result := make([]string, len(dict.m))
|
result := make([]string, len(dict.m))
|
||||||
i := 0
|
i := 0
|
||||||
for k := range dict.m {
|
for k := range dict.m {
|
||||||
result[i] = k
|
result[i] = k
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dict *SimpleDict) ForEach(consumer Consumer) {
|
func (dict *SimpleDict) ForEach(consumer Consumer) {
|
||||||
for k, v := range dict.m {
|
for k, v := range dict.m {
|
||||||
if !consumer(k, v) {
|
if !consumer(k, v) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dict *SimpleDict) RandomKeys(limit int) []string {
|
func (dict *SimpleDict) RandomKeys(limit int) []string {
|
||||||
result := make([]string, limit)
|
result := make([]string, limit)
|
||||||
for i := 0; i < limit; i++ {
|
for i := 0; i < limit; i++ {
|
||||||
for k := range dict.m {
|
for k := range dict.m {
|
||||||
result[i] = k
|
result[i] = k
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dict *SimpleDict) RandomDistinctKeys(limit int) []string {
|
func (dict *SimpleDict) RandomDistinctKeys(limit int) []string {
|
||||||
size := limit
|
size := limit
|
||||||
if size > len(dict.m) {
|
if size > len(dict.m) {
|
||||||
size = len(dict.m)
|
size = len(dict.m)
|
||||||
}
|
}
|
||||||
result := make([]string, size)
|
result := make([]string, size)
|
||||||
i := 0
|
i := 0
|
||||||
for k := range dict.m {
|
for k := range dict.m {
|
||||||
if i == limit {
|
if i == limit {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
result[i] = k
|
result[i] = k
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
@@ -3,324 +3,324 @@ package list
|
|||||||
import "github.com/HDT3213/godis/src/datastruct/utils"
|
import "github.com/HDT3213/godis/src/datastruct/utils"
|
||||||
|
|
||||||
type LinkedList struct {
|
type LinkedList struct {
|
||||||
first *node
|
first *node
|
||||||
last *node
|
last *node
|
||||||
size int
|
size int
|
||||||
}
|
}
|
||||||
|
|
||||||
type node struct {
|
type node struct {
|
||||||
val interface{}
|
val interface{}
|
||||||
prev *node
|
prev *node
|
||||||
next * node
|
next *node
|
||||||
}
|
}
|
||||||
|
|
||||||
func (list *LinkedList)Add(val interface{}) {
|
func (list *LinkedList) Add(val interface{}) {
|
||||||
if list == nil {
|
if list == nil {
|
||||||
panic("list is nil")
|
panic("list is nil")
|
||||||
}
|
}
|
||||||
n := &node{
|
n := &node{
|
||||||
val: val,
|
val: val,
|
||||||
}
|
}
|
||||||
if list.last == nil {
|
if list.last == nil {
|
||||||
// empty list
|
// empty list
|
||||||
list.first = n
|
list.first = n
|
||||||
list.last = n
|
list.last = n
|
||||||
} else {
|
} else {
|
||||||
n.prev = list.last
|
n.prev = list.last
|
||||||
list.last.next = n
|
list.last.next = n
|
||||||
list.last = n
|
list.last = n
|
||||||
}
|
}
|
||||||
list.size++
|
list.size++
|
||||||
}
|
}
|
||||||
|
|
||||||
func (list *LinkedList)find(index int)(n *node) {
|
func (list *LinkedList) find(index int) (n *node) {
|
||||||
if index < list.size / 2 {
|
if index < list.size/2 {
|
||||||
n := list.first
|
n := list.first
|
||||||
for i := 0; i < index; i++ {
|
for i := 0; i < index; i++ {
|
||||||
n = n.next
|
n = n.next
|
||||||
}
|
}
|
||||||
return n
|
return n
|
||||||
} else {
|
} else {
|
||||||
n := list.last
|
n := list.last
|
||||||
for i := list.size - 1; i > index; i-- {
|
for i := list.size - 1; i > index; i-- {
|
||||||
n = n.prev
|
n = n.prev
|
||||||
}
|
}
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (list *LinkedList)Get(index int)(val interface{}) {
|
func (list *LinkedList) Get(index int) (val interface{}) {
|
||||||
if list == nil {
|
if list == nil {
|
||||||
panic("list is nil")
|
panic("list is nil")
|
||||||
}
|
}
|
||||||
if index < 0 || index >= list.size {
|
if index < 0 || index >= list.size {
|
||||||
panic("index out of bound")
|
panic("index out of bound")
|
||||||
}
|
}
|
||||||
return list.find(index).val
|
return list.find(index).val
|
||||||
}
|
}
|
||||||
|
|
||||||
func (list *LinkedList)Set(index int, val interface{}) {
|
func (list *LinkedList) Set(index int, val interface{}) {
|
||||||
if list == nil {
|
if list == nil {
|
||||||
panic("list is nil")
|
panic("list is nil")
|
||||||
}
|
}
|
||||||
if index < 0 || index > list.size {
|
if index < 0 || index > list.size {
|
||||||
panic("index out of bound")
|
panic("index out of bound")
|
||||||
}
|
}
|
||||||
n := list.find(index)
|
n := list.find(index)
|
||||||
n.val = val
|
n.val = val
|
||||||
}
|
}
|
||||||
|
|
||||||
func (list *LinkedList)Insert(index int, val interface{}) {
|
func (list *LinkedList) Insert(index int, val interface{}) {
|
||||||
if list == nil {
|
if list == nil {
|
||||||
panic("list is nil")
|
panic("list is nil")
|
||||||
}
|
}
|
||||||
if index < 0 || index > list.size {
|
if index < 0 || index > list.size {
|
||||||
panic("index out of bound")
|
panic("index out of bound")
|
||||||
}
|
}
|
||||||
|
|
||||||
if index == list.size {
|
if index == list.size {
|
||||||
list.Add(val)
|
list.Add(val)
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
// list is not empty
|
// list is not empty
|
||||||
pivot := list.find(index)
|
pivot := list.find(index)
|
||||||
n := &node{
|
n := &node{
|
||||||
val: val,
|
val: val,
|
||||||
prev: pivot.prev,
|
prev: pivot.prev,
|
||||||
next: pivot,
|
next: pivot,
|
||||||
}
|
}
|
||||||
if pivot.prev == nil {
|
if pivot.prev == nil {
|
||||||
list.first = n
|
list.first = n
|
||||||
} else {
|
} else {
|
||||||
pivot.prev.next = n
|
pivot.prev.next = n
|
||||||
}
|
}
|
||||||
pivot.prev = n
|
pivot.prev = n
|
||||||
list.size++
|
list.size++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (list *LinkedList)removeNode(n *node) {
|
func (list *LinkedList) removeNode(n *node) {
|
||||||
if n.prev == nil {
|
if n.prev == nil {
|
||||||
list.first = n.next
|
list.first = n.next
|
||||||
} else {
|
} else {
|
||||||
n.prev.next = n.next
|
n.prev.next = n.next
|
||||||
}
|
}
|
||||||
if n.next == nil {
|
if n.next == nil {
|
||||||
list.last = n.prev
|
list.last = n.prev
|
||||||
} else {
|
} else {
|
||||||
n.next.prev = n.prev
|
n.next.prev = n.prev
|
||||||
}
|
}
|
||||||
|
|
||||||
// for gc
|
// for gc
|
||||||
n.prev = nil
|
n.prev = nil
|
||||||
n.next = nil
|
n.next = nil
|
||||||
|
|
||||||
list.size--
|
list.size--
|
||||||
}
|
}
|
||||||
|
|
||||||
func (list *LinkedList)Remove(index int)(val interface{}) {
|
func (list *LinkedList) Remove(index int) (val interface{}) {
|
||||||
if list == nil {
|
if list == nil {
|
||||||
panic("list is nil")
|
panic("list is nil")
|
||||||
}
|
}
|
||||||
if index < 0 || index >= list.size {
|
if index < 0 || index >= list.size {
|
||||||
panic("index out of bound")
|
panic("index out of bound")
|
||||||
}
|
}
|
||||||
|
|
||||||
n := list.find(index)
|
n := list.find(index)
|
||||||
list.removeNode(n)
|
list.removeNode(n)
|
||||||
return n.val
|
return n.val
|
||||||
}
|
}
|
||||||
|
|
||||||
func (list *LinkedList)RemoveLast()(val interface{}) {
|
func (list *LinkedList) RemoveLast() (val interface{}) {
|
||||||
if list == nil {
|
if list == nil {
|
||||||
panic("list is nil")
|
panic("list is nil")
|
||||||
}
|
}
|
||||||
if list.last == nil {
|
if list.last == nil {
|
||||||
// empty list
|
// empty list
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
n := list.last
|
n := list.last
|
||||||
list.removeNode(n)
|
list.removeNode(n)
|
||||||
return n.val
|
return n.val
|
||||||
}
|
}
|
||||||
|
|
||||||
func (list *LinkedList)RemoveAllByVal(val interface{})int {
|
func (list *LinkedList) RemoveAllByVal(val interface{}) int {
|
||||||
if list == nil {
|
if list == nil {
|
||||||
panic("list is nil")
|
panic("list is nil")
|
||||||
}
|
}
|
||||||
n := list.first
|
n := list.first
|
||||||
removed := 0
|
removed := 0
|
||||||
for n != nil {
|
for n != nil {
|
||||||
var toRemoveNode *node
|
var toRemoveNode *node
|
||||||
if utils.Equals(n.val, val) {
|
if utils.Equals(n.val, val) {
|
||||||
toRemoveNode = n
|
toRemoveNode = n
|
||||||
}
|
}
|
||||||
if n.next == nil {
|
if n.next == nil {
|
||||||
if toRemoveNode != nil {
|
if toRemoveNode != nil {
|
||||||
removed++
|
removed++
|
||||||
list.removeNode(toRemoveNode)
|
list.removeNode(toRemoveNode)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
n = n.next
|
n = n.next
|
||||||
}
|
}
|
||||||
if toRemoveNode != nil {
|
if toRemoveNode != nil {
|
||||||
removed++
|
removed++
|
||||||
list.removeNode(toRemoveNode)
|
list.removeNode(toRemoveNode)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return removed
|
return removed
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* remove at most `count` values of the specified value in this list
|
* remove at most `count` values of the specified value in this list
|
||||||
* scan from left to right
|
* scan from left to right
|
||||||
*/
|
*/
|
||||||
func (list *LinkedList) RemoveByVal(val interface{}, count int)int {
|
func (list *LinkedList) RemoveByVal(val interface{}, count int) int {
|
||||||
if list == nil {
|
if list == nil {
|
||||||
panic("list is nil")
|
panic("list is nil")
|
||||||
}
|
}
|
||||||
n := list.first
|
n := list.first
|
||||||
removed := 0
|
removed := 0
|
||||||
for n != nil {
|
for n != nil {
|
||||||
var toRemoveNode *node
|
var toRemoveNode *node
|
||||||
if utils.Equals(n.val, val) {
|
if utils.Equals(n.val, val) {
|
||||||
toRemoveNode = n
|
toRemoveNode = n
|
||||||
}
|
}
|
||||||
if n.next == nil {
|
if n.next == nil {
|
||||||
if toRemoveNode != nil {
|
if toRemoveNode != nil {
|
||||||
removed++
|
removed++
|
||||||
list.removeNode(toRemoveNode)
|
list.removeNode(toRemoveNode)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
n = n.next
|
n = n.next
|
||||||
}
|
}
|
||||||
|
|
||||||
if toRemoveNode != nil {
|
if toRemoveNode != nil {
|
||||||
removed++
|
removed++
|
||||||
list.removeNode(toRemoveNode)
|
list.removeNode(toRemoveNode)
|
||||||
}
|
}
|
||||||
if removed == count {
|
if removed == count {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return removed
|
return removed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (list *LinkedList) ReverseRemoveByVal(val interface{}, count int)int {
|
func (list *LinkedList) ReverseRemoveByVal(val interface{}, count int) int {
|
||||||
if list == nil {
|
if list == nil {
|
||||||
panic("list is nil")
|
panic("list is nil")
|
||||||
}
|
}
|
||||||
n := list.last
|
n := list.last
|
||||||
removed := 0
|
removed := 0
|
||||||
for n != nil {
|
for n != nil {
|
||||||
var toRemoveNode *node
|
var toRemoveNode *node
|
||||||
if utils.Equals(n.val, val) {
|
if utils.Equals(n.val, val) {
|
||||||
toRemoveNode = n
|
toRemoveNode = n
|
||||||
}
|
}
|
||||||
if n.prev == nil {
|
if n.prev == nil {
|
||||||
if toRemoveNode != nil {
|
if toRemoveNode != nil {
|
||||||
removed++
|
removed++
|
||||||
list.removeNode(toRemoveNode)
|
list.removeNode(toRemoveNode)
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
n = n.prev
|
n = n.prev
|
||||||
}
|
}
|
||||||
|
|
||||||
if toRemoveNode != nil {
|
if toRemoveNode != nil {
|
||||||
removed++
|
removed++
|
||||||
list.removeNode(toRemoveNode)
|
list.removeNode(toRemoveNode)
|
||||||
}
|
}
|
||||||
if removed == count {
|
if removed == count {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return removed
|
return removed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (list *LinkedList)Len()int {
|
func (list *LinkedList) Len() int {
|
||||||
if list == nil {
|
if list == nil {
|
||||||
panic("list is nil")
|
panic("list is nil")
|
||||||
}
|
}
|
||||||
return list.size
|
return list.size
|
||||||
}
|
}
|
||||||
|
|
||||||
func (list *LinkedList)ForEach(consumer func(int, interface{})bool) {
|
func (list *LinkedList) ForEach(consumer func(int, interface{}) bool) {
|
||||||
if list == nil {
|
if list == nil {
|
||||||
panic("list is nil")
|
panic("list is nil")
|
||||||
}
|
}
|
||||||
n := list.first
|
n := list.first
|
||||||
i := 0
|
i := 0
|
||||||
for n != nil {
|
for n != nil {
|
||||||
goNext := consumer(i, n.val)
|
goNext := consumer(i, n.val)
|
||||||
if !goNext || n.next == nil {
|
if !goNext || n.next == nil {
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
i++
|
i++
|
||||||
n = n.next
|
n = n.next
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (list *LinkedList)Contains(val interface{})bool {
|
func (list *LinkedList) Contains(val interface{}) bool {
|
||||||
contains := false
|
contains := false
|
||||||
list.ForEach(func(i int, actual interface{}) bool {
|
list.ForEach(func(i int, actual interface{}) bool {
|
||||||
if actual == val {
|
if actual == val {
|
||||||
contains = true
|
contains = true
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return contains
|
return contains
|
||||||
}
|
}
|
||||||
|
|
||||||
func (list *LinkedList)Range(start int, stop int)[]interface{} {
|
func (list *LinkedList) Range(start int, stop int) []interface{} {
|
||||||
if list == nil {
|
if list == nil {
|
||||||
panic("list is nil")
|
panic("list is nil")
|
||||||
}
|
}
|
||||||
if start < 0 || start >= list.size {
|
if start < 0 || start >= list.size {
|
||||||
panic("`start` out of range")
|
panic("`start` out of range")
|
||||||
}
|
}
|
||||||
if stop < start || stop > list.size {
|
if stop < start || stop > list.size {
|
||||||
panic("`stop` out of range")
|
panic("`stop` out of range")
|
||||||
}
|
}
|
||||||
|
|
||||||
sliceSize := stop - start
|
sliceSize := stop - start
|
||||||
slice := make([]interface{}, sliceSize)
|
slice := make([]interface{}, sliceSize)
|
||||||
n := list.first
|
n := list.first
|
||||||
i := 0
|
i := 0
|
||||||
sliceIndex := 0
|
sliceIndex := 0
|
||||||
for n != nil {
|
for n != nil {
|
||||||
if i >= start && i < stop {
|
if i >= start && i < stop {
|
||||||
slice[sliceIndex] = n.val
|
slice[sliceIndex] = n.val
|
||||||
sliceIndex++
|
sliceIndex++
|
||||||
} else if i >= stop {
|
} else if i >= stop {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if n.next == nil {
|
if n.next == nil {
|
||||||
break
|
break
|
||||||
} else {
|
} else {
|
||||||
i++
|
i++
|
||||||
n = n.next
|
n = n.next
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return slice
|
return slice
|
||||||
}
|
}
|
||||||
|
|
||||||
func Make(vals ...interface{}) *LinkedList {
|
func Make(vals ...interface{}) *LinkedList {
|
||||||
list := LinkedList{}
|
list := LinkedList{}
|
||||||
for _, v := range vals {
|
for _, v := range vals {
|
||||||
list.Add(v)
|
list.Add(v)
|
||||||
}
|
}
|
||||||
return &list
|
return &list
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeBytesList(vals ...[]byte) *LinkedList {
|
func MakeBytesList(vals ...[]byte) *LinkedList {
|
||||||
list := LinkedList{}
|
list := LinkedList{}
|
||||||
for _, v := range vals {
|
for _, v := range vals {
|
||||||
list.Add(v)
|
list.Add(v)
|
||||||
}
|
}
|
||||||
return &list
|
return &list
|
||||||
}
|
}
|
||||||
|
@@ -1,215 +1,215 @@
|
|||||||
package list
|
package list
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"strconv"
|
||||||
"strconv"
|
"strings"
|
||||||
"strings"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ToString(list *LinkedList) string {
|
func ToString(list *LinkedList) string {
|
||||||
arr := make([]string, list.size)
|
arr := make([]string, list.size)
|
||||||
list.ForEach(func(i int, v interface{}) bool {
|
list.ForEach(func(i int, v interface{}) bool {
|
||||||
integer, _ := v.(int)
|
integer, _ := v.(int)
|
||||||
arr[i] = strconv.Itoa(integer)
|
arr[i] = strconv.Itoa(integer)
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return "[" + strings.Join(arr, ", ") + "]"
|
return "[" + strings.Join(arr, ", ") + "]"
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAdd(t *testing.T) {
|
func TestAdd(t *testing.T) {
|
||||||
list := Make()
|
list := Make()
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
list.Add(i)
|
list.Add(i)
|
||||||
}
|
}
|
||||||
list.ForEach(func(i int, v interface{}) bool {
|
list.ForEach(func(i int, v interface{}) bool {
|
||||||
intVal, _ := v.(int)
|
intVal, _ := v.(int)
|
||||||
if intVal != i {
|
if intVal != i {
|
||||||
t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGet(t *testing.T) {
|
func TestGet(t *testing.T) {
|
||||||
list := Make()
|
list := Make()
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
list.Add(i)
|
list.Add(i)
|
||||||
}
|
}
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
v := list.Get(i)
|
v := list.Get(i)
|
||||||
k, _ := v.(int)
|
k, _ := v.(int)
|
||||||
if i != k {
|
if i != k {
|
||||||
t.Error("get test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(k))
|
t.Error("get test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(k))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRemove(t *testing.T) {
|
func TestRemove(t *testing.T) {
|
||||||
list := Make()
|
list := Make()
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
list.Add(i)
|
list.Add(i)
|
||||||
}
|
}
|
||||||
for i := 9; i >= 0; i-- {
|
for i := 9; i >= 0; i-- {
|
||||||
list.Remove(i)
|
list.Remove(i)
|
||||||
if i != list.Len() {
|
if i != list.Len() {
|
||||||
t.Error("remove test fail: expected size " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(list.Len()))
|
t.Error("remove test fail: expected size " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(list.Len()))
|
||||||
}
|
}
|
||||||
list.ForEach(func(i int, v interface{}) bool {
|
list.ForEach(func(i int, v interface{}) bool {
|
||||||
intVal, _ := v.(int)
|
intVal, _ := v.(int)
|
||||||
if intVal != i {
|
if intVal != i {
|
||||||
t.Error("remove test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
t.Error("remove test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRemoveVal(t *testing.T) {
|
func TestRemoveVal(t *testing.T) {
|
||||||
list := Make()
|
list := Make()
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
list.Add(i)
|
list.Add(i)
|
||||||
list.Add(i)
|
list.Add(i)
|
||||||
}
|
}
|
||||||
for index := 0; index < list.Len(); index++ {
|
for index := 0; index < list.Len(); index++ {
|
||||||
list.RemoveAllByVal(index)
|
list.RemoveAllByVal(index)
|
||||||
list.ForEach(func(i int, v interface{}) bool {
|
list.ForEach(func(i int, v interface{}) bool {
|
||||||
intVal, _ := v.(int)
|
intVal, _ := v.(int)
|
||||||
if intVal == index {
|
if intVal == index {
|
||||||
t.Error("remove test fail: found " + strconv.Itoa(index) + " at index: " + strconv.Itoa(i))
|
t.Error("remove test fail: found " + strconv.Itoa(index) + " at index: " + strconv.Itoa(i))
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
list = Make()
|
list = Make()
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
list.Add(i)
|
list.Add(i)
|
||||||
list.Add(i)
|
list.Add(i)
|
||||||
}
|
}
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
list.RemoveByVal(i, 1)
|
list.RemoveByVal(i, 1)
|
||||||
}
|
}
|
||||||
list.ForEach(func(i int, v interface{}) bool {
|
list.ForEach(func(i int, v interface{}) bool {
|
||||||
intVal, _ := v.(int)
|
intVal, _ := v.(int)
|
||||||
if intVal != i {
|
if intVal != i {
|
||||||
t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
list.RemoveByVal(i, 1)
|
list.RemoveByVal(i, 1)
|
||||||
}
|
}
|
||||||
if list.Len() != 0 {
|
if list.Len() != 0 {
|
||||||
t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len()))
|
t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len()))
|
||||||
}
|
}
|
||||||
|
|
||||||
list = Make()
|
list = Make()
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
list.Add(i)
|
list.Add(i)
|
||||||
list.Add(i)
|
list.Add(i)
|
||||||
}
|
}
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
list.ReverseRemoveByVal(i, 1)
|
list.ReverseRemoveByVal(i, 1)
|
||||||
}
|
}
|
||||||
list.ForEach(func(i int, v interface{}) bool {
|
list.ForEach(func(i int, v interface{}) bool {
|
||||||
intVal, _ := v.(int)
|
intVal, _ := v.(int)
|
||||||
if intVal != i {
|
if intVal != i {
|
||||||
t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
t.Error("test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
list.ReverseRemoveByVal(i, 1)
|
list.ReverseRemoveByVal(i, 1)
|
||||||
}
|
}
|
||||||
if list.Len() != 0 {
|
if list.Len() != 0 {
|
||||||
t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len()))
|
t.Error("test fail: expected 0, actual: " + strconv.Itoa(list.Len()))
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInsert(t *testing.T) {
|
func TestInsert(t *testing.T) {
|
||||||
list := Make()
|
list := Make()
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
list.Add(i)
|
list.Add(i)
|
||||||
}
|
}
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
list.Insert(i*2, i)
|
list.Insert(i*2, i)
|
||||||
|
|
||||||
list.ForEach(func(j int, v interface{}) bool {
|
list.ForEach(func(j int, v interface{}) bool {
|
||||||
var expected int
|
var expected int
|
||||||
if j < (i + 1) * 2 {
|
if j < (i+1)*2 {
|
||||||
if j%2 == 0 {
|
if j%2 == 0 {
|
||||||
expected = j / 2
|
expected = j / 2
|
||||||
} else {
|
} else {
|
||||||
expected = (j - 1) / 2
|
expected = (j - 1) / 2
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
expected = j - i - 1
|
expected = j - i - 1
|
||||||
}
|
}
|
||||||
actual, _ := list.Get(j).(int)
|
actual, _ := list.Get(j).(int)
|
||||||
if actual != expected {
|
if actual != expected {
|
||||||
t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual))
|
t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual))
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
|
|
||||||
for j := 0; j < list.Len(); j++ {
|
for j := 0; j < list.Len(); j++ {
|
||||||
var expected int
|
var expected int
|
||||||
if j < (i + 1) * 2 {
|
if j < (i+1)*2 {
|
||||||
if j%2 == 0 {
|
if j%2 == 0 {
|
||||||
expected = j / 2
|
expected = j / 2
|
||||||
} else {
|
} else {
|
||||||
expected = (j - 1) / 2
|
expected = (j - 1) / 2
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
expected = j - i - 1
|
expected = j - i - 1
|
||||||
}
|
}
|
||||||
actual, _ := list.Get(j).(int)
|
actual, _ := list.Get(j).(int)
|
||||||
if actual != expected {
|
if actual != expected {
|
||||||
t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual))
|
t.Error("insert test fail: at i " + strconv.Itoa(i) + " expected " + strconv.Itoa(expected) + ", actual: " + strconv.Itoa(actual))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRemoveLast(t *testing.T) {
|
func TestRemoveLast(t *testing.T) {
|
||||||
list := Make()
|
list := Make()
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
list.Add(i)
|
list.Add(i)
|
||||||
}
|
}
|
||||||
for i := 9; i >= 0; i-- {
|
for i := 9; i >= 0; i-- {
|
||||||
val := list.RemoveLast()
|
val := list.RemoveLast()
|
||||||
intVal, _ := val.(int)
|
intVal, _ := val.(int)
|
||||||
if intVal != i {
|
if intVal != i {
|
||||||
t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
t.Error("add test fail: expected " + strconv.Itoa(i) + ", actual: " + strconv.Itoa(intVal))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRange(t *testing.T) {
|
func TestRange(t *testing.T) {
|
||||||
list := Make()
|
list := Make()
|
||||||
size := 10
|
size := 10
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
list.Add(i)
|
list.Add(i)
|
||||||
}
|
}
|
||||||
for start := 0; start < size; start++ {
|
for start := 0; start < size; start++ {
|
||||||
for stop := start; stop < size; stop++ {
|
for stop := start; stop < size; stop++ {
|
||||||
slice := list.Range(start, stop)
|
slice := list.Range(start, stop)
|
||||||
if len(slice) != stop - start {
|
if len(slice) != stop-start {
|
||||||
t.Error("expected " + strconv.Itoa(stop - start) + ", get: " + strconv.Itoa(len(slice)) +
|
t.Error("expected " + strconv.Itoa(stop-start) + ", get: " + strconv.Itoa(len(slice)) +
|
||||||
", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]")
|
", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]")
|
||||||
}
|
}
|
||||||
sliceIndex := 0
|
sliceIndex := 0
|
||||||
for i := start; i < stop; i++ {
|
for i := start; i < stop; i++ {
|
||||||
val := slice[sliceIndex]
|
val := slice[sliceIndex]
|
||||||
intVal, _ := val.(int)
|
intVal, _ := val.(int)
|
||||||
if intVal != i {
|
if intVal != i {
|
||||||
t.Error("expected " + strconv.Itoa(i) + ", get: " + strconv.Itoa(intVal) +
|
t.Error("expected " + strconv.Itoa(i) + ", get: " + strconv.Itoa(intVal) +
|
||||||
", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]")
|
", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]")
|
||||||
}
|
}
|
||||||
sliceIndex++
|
sliceIndex++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -1,153 +1,152 @@
|
|||||||
package lock
|
package lock
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
prime32 = uint32(16777619)
|
prime32 = uint32(16777619)
|
||||||
)
|
)
|
||||||
|
|
||||||
type Locks struct {
|
type Locks struct {
|
||||||
table []*sync.RWMutex
|
table []*sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func Make(tableSize int) *Locks {
|
func Make(tableSize int) *Locks {
|
||||||
table := make([]*sync.RWMutex, tableSize)
|
table := make([]*sync.RWMutex, tableSize)
|
||||||
for i := 0; i < tableSize; i++ {
|
for i := 0; i < tableSize; i++ {
|
||||||
table[i] = &sync.RWMutex{}
|
table[i] = &sync.RWMutex{}
|
||||||
}
|
}
|
||||||
return &Locks{
|
return &Locks{
|
||||||
table: table,
|
table: table,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func fnv32(key string) uint32 {
|
func fnv32(key string) uint32 {
|
||||||
hash := uint32(2166136261)
|
hash := uint32(2166136261)
|
||||||
for i := 0; i < len(key); i++ {
|
for i := 0; i < len(key); i++ {
|
||||||
hash *= prime32
|
hash *= prime32
|
||||||
hash ^= uint32(key[i])
|
hash ^= uint32(key[i])
|
||||||
}
|
}
|
||||||
return hash
|
return hash
|
||||||
}
|
}
|
||||||
|
|
||||||
func (locks *Locks) spread(hashCode uint32) uint32 {
|
func (locks *Locks) spread(hashCode uint32) uint32 {
|
||||||
if locks == nil {
|
if locks == nil {
|
||||||
panic("dict is nil")
|
panic("dict is nil")
|
||||||
}
|
}
|
||||||
tableSize := uint32(len(locks.table))
|
tableSize := uint32(len(locks.table))
|
||||||
return (tableSize - 1) & uint32(hashCode)
|
return (tableSize - 1) & uint32(hashCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (locks *Locks)Lock(key string) {
|
func (locks *Locks) Lock(key string) {
|
||||||
index := locks.spread(fnv32(key))
|
index := locks.spread(fnv32(key))
|
||||||
mu := locks.table[index]
|
mu := locks.table[index]
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (locks *Locks)RLock(key string) {
|
func (locks *Locks) RLock(key string) {
|
||||||
index := locks.spread(fnv32(key))
|
index := locks.spread(fnv32(key))
|
||||||
mu := locks.table[index]
|
mu := locks.table[index]
|
||||||
mu.RLock()
|
mu.RLock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (locks *Locks)UnLock(key string) {
|
func (locks *Locks) UnLock(key string) {
|
||||||
index := locks.spread(fnv32(key))
|
index := locks.spread(fnv32(key))
|
||||||
mu := locks.table[index]
|
mu := locks.table[index]
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (locks *Locks)RUnLock(key string) {
|
func (locks *Locks) RUnLock(key string) {
|
||||||
index := locks.spread(fnv32(key))
|
index := locks.spread(fnv32(key))
|
||||||
mu := locks.table[index]
|
mu := locks.table[index]
|
||||||
mu.RUnlock()
|
mu.RUnlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (locks *Locks) toLockIndices(keys []string, reverse bool) []uint32 {
|
func (locks *Locks) toLockIndices(keys []string, reverse bool) []uint32 {
|
||||||
indexMap := make(map[uint32]bool)
|
indexMap := make(map[uint32]bool)
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
index := locks.spread(fnv32(key))
|
index := locks.spread(fnv32(key))
|
||||||
indexMap[index] = true
|
indexMap[index] = true
|
||||||
}
|
}
|
||||||
indices := make([]uint32, 0, len(indexMap))
|
indices := make([]uint32, 0, len(indexMap))
|
||||||
for index := range indexMap {
|
for index := range indexMap {
|
||||||
indices = append(indices, index)
|
indices = append(indices, index)
|
||||||
}
|
}
|
||||||
sort.Slice(indices, func(i, j int) bool {
|
sort.Slice(indices, func(i, j int) bool {
|
||||||
if !reverse {
|
if !reverse {
|
||||||
return indices[i] < indices[j]
|
return indices[i] < indices[j]
|
||||||
} else {
|
} else {
|
||||||
return indices[i] > indices[j]
|
return indices[i] > indices[j]
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
return indices
|
return indices
|
||||||
}
|
}
|
||||||
|
|
||||||
func (locks *Locks)Locks(keys ...string) {
|
func (locks *Locks) Locks(keys ...string) {
|
||||||
indices := locks.toLockIndices(keys, false)
|
indices := locks.toLockIndices(keys, false)
|
||||||
for _, index := range indices {
|
for _, index := range indices {
|
||||||
mu := locks.table[index]
|
mu := locks.table[index]
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (locks *Locks)RLocks(keys ...string) {
|
func (locks *Locks) RLocks(keys ...string) {
|
||||||
indices := locks.toLockIndices(keys, false)
|
indices := locks.toLockIndices(keys, false)
|
||||||
for _, index := range indices {
|
for _, index := range indices {
|
||||||
mu := locks.table[index]
|
mu := locks.table[index]
|
||||||
mu.RLock()
|
mu.RLock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (locks *Locks) UnLocks(keys ...string) {
|
||||||
func (locks *Locks)UnLocks(keys ...string) {
|
indices := locks.toLockIndices(keys, true)
|
||||||
indices := locks.toLockIndices(keys, true)
|
for _, index := range indices {
|
||||||
for _, index := range indices {
|
mu := locks.table[index]
|
||||||
mu := locks.table[index]
|
mu.Unlock()
|
||||||
mu.Unlock()
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (locks *Locks)RUnLocks(keys ...string) {
|
func (locks *Locks) RUnLocks(keys ...string) {
|
||||||
indices := locks.toLockIndices(keys, true)
|
indices := locks.toLockIndices(keys, true)
|
||||||
for _, index := range indices {
|
for _, index := range indices {
|
||||||
mu := locks.table[index]
|
mu := locks.table[index]
|
||||||
mu.RUnlock()
|
mu.RUnlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GoID() int {
|
func GoID() int {
|
||||||
var buf [64]byte
|
var buf [64]byte
|
||||||
n := runtime.Stack(buf[:], false)
|
n := runtime.Stack(buf[:], false)
|
||||||
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
|
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
|
||||||
id, err := strconv.Atoi(idField)
|
id, err := strconv.Atoi(idField)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
|
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
|
||||||
}
|
}
|
||||||
return id
|
return id
|
||||||
}
|
}
|
||||||
|
|
||||||
func debug(testing.T) {
|
func debug(testing.T) {
|
||||||
lm := Locks{}
|
lm := Locks{}
|
||||||
size := 10
|
size := 10
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(size)
|
wg.Add(size)
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
go func(i int) {
|
go func(i int) {
|
||||||
lm.Locks("1", "2")
|
lm.Locks("1", "2")
|
||||||
println("go: " + strconv.Itoa(GoID()))
|
println("go: " + strconv.Itoa(GoID()))
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
println("go: " + strconv.Itoa(GoID()))
|
println("go: " + strconv.Itoa(GoID()))
|
||||||
lm.UnLocks("1", "2")
|
lm.UnLocks("1", "2")
|
||||||
wg.Done()
|
wg.Done()
|
||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
@@ -1,32 +1,32 @@
|
|||||||
package set
|
package set
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSet(t *testing.T) {
|
func TestSet(t *testing.T) {
|
||||||
size := 10
|
size := 10
|
||||||
set := Make()
|
set := Make()
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
set.Add(strconv.Itoa(i))
|
set.Add(strconv.Itoa(i))
|
||||||
}
|
}
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
ok := set.Has(strconv.Itoa(i))
|
ok := set.Has(strconv.Itoa(i))
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Error("expected true actual false, key: " + strconv.Itoa(i))
|
t.Error("expected true actual false, key: " + strconv.Itoa(i))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
ok := set.Remove(strconv.Itoa(i))
|
ok := set.Remove(strconv.Itoa(i))
|
||||||
if ok != 1 {
|
if ok != 1 {
|
||||||
t.Error("expected true actual false, key: " + strconv.Itoa(i))
|
t.Error("expected true actual false, key: " + strconv.Itoa(i))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
ok := set.Has(strconv.Itoa(i))
|
ok := set.Has(strconv.Itoa(i))
|
||||||
if ok {
|
if ok {
|
||||||
t.Error("expected false actual true, key: " + strconv.Itoa(i))
|
t.Error("expected false actual true, key: " + strconv.Itoa(i))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -1,8 +1,8 @@
|
|||||||
package sortedset
|
package sortedset
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@@ -14,78 +14,78 @@ import (
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
const (
|
const (
|
||||||
negativeInf int8 = -1
|
negativeInf int8 = -1
|
||||||
positiveInf int8 = 1
|
positiveInf int8 = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
type ScoreBorder struct {
|
type ScoreBorder struct {
|
||||||
Inf int8
|
Inf int8
|
||||||
Value float64
|
Value float64
|
||||||
Exclude bool
|
Exclude bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// if max.greater(score) then the score is within the upper border
|
// if max.greater(score) then the score is within the upper border
|
||||||
// do not use min.greater()
|
// do not use min.greater()
|
||||||
func (border *ScoreBorder)greater(value float64)bool {
|
func (border *ScoreBorder) greater(value float64) bool {
|
||||||
if border.Inf == negativeInf {
|
if border.Inf == negativeInf {
|
||||||
return false
|
return false
|
||||||
} else if border.Inf == positiveInf {
|
} else if border.Inf == positiveInf {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if border.Exclude {
|
if border.Exclude {
|
||||||
return border.Value > value
|
return border.Value > value
|
||||||
} else {
|
} else {
|
||||||
return border.Value >= value
|
return border.Value >= value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (border *ScoreBorder)less(value float64)bool {
|
func (border *ScoreBorder) less(value float64) bool {
|
||||||
if border.Inf == negativeInf {
|
if border.Inf == negativeInf {
|
||||||
return true
|
return true
|
||||||
} else if border.Inf == positiveInf {
|
} else if border.Inf == positiveInf {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if border.Exclude {
|
if border.Exclude {
|
||||||
return border.Value < value
|
return border.Value < value
|
||||||
} else {
|
} else {
|
||||||
return border.Value <= value
|
return border.Value <= value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var positiveInfBorder = &ScoreBorder {
|
var positiveInfBorder = &ScoreBorder{
|
||||||
Inf: positiveInf,
|
Inf: positiveInf,
|
||||||
}
|
}
|
||||||
|
|
||||||
var negativeInfBorder = &ScoreBorder {
|
var negativeInfBorder = &ScoreBorder{
|
||||||
Inf: negativeInf,
|
Inf: negativeInf,
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseScoreBorder(s string)(*ScoreBorder, error) {
|
func ParseScoreBorder(s string) (*ScoreBorder, error) {
|
||||||
if s == "inf" || s == "+inf" {
|
if s == "inf" || s == "+inf" {
|
||||||
return positiveInfBorder, nil
|
return positiveInfBorder, nil
|
||||||
}
|
}
|
||||||
if s == "-inf" {
|
if s == "-inf" {
|
||||||
return negativeInfBorder, nil
|
return negativeInfBorder, nil
|
||||||
}
|
}
|
||||||
if s[0] == '(' {
|
if s[0] == '(' {
|
||||||
value, err := strconv.ParseFloat(s[1:], 64)
|
value, err := strconv.ParseFloat(s[1:], 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("ERR min or max is not a float")
|
return nil, errors.New("ERR min or max is not a float")
|
||||||
}
|
}
|
||||||
return &ScoreBorder{
|
return &ScoreBorder{
|
||||||
Inf: 0,
|
Inf: 0,
|
||||||
Value: value,
|
Value: value,
|
||||||
Exclude: true,
|
Exclude: true,
|
||||||
}, nil
|
}, nil
|
||||||
} else {
|
} else {
|
||||||
value, err := strconv.ParseFloat(s, 64)
|
value, err := strconv.ParseFloat(s, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("ERR min or max is not a float")
|
return nil, errors.New("ERR min or max is not a float")
|
||||||
}
|
}
|
||||||
return &ScoreBorder{
|
return &ScoreBorder{
|
||||||
Inf: 0,
|
Inf: 0,
|
||||||
Value: value,
|
Value: value,
|
||||||
Exclude: false,
|
Exclude: false,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -3,130 +3,129 @@ package sortedset
|
|||||||
import "math/rand"
|
import "math/rand"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
maxLevel = 16
|
maxLevel = 16
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
type Element struct {
|
type Element struct {
|
||||||
Member string
|
Member string
|
||||||
Score float64
|
Score float64
|
||||||
}
|
}
|
||||||
|
|
||||||
// level aspect of a Node
|
// level aspect of a Node
|
||||||
type Level struct {
|
type Level struct {
|
||||||
forward *Node // forward node has greater score
|
forward *Node // forward node has greater score
|
||||||
span int64
|
span int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type Node struct {
|
type Node struct {
|
||||||
Element
|
Element
|
||||||
backward *Node
|
backward *Node
|
||||||
level []*Level // level[0] is base level
|
level []*Level // level[0] is base level
|
||||||
}
|
}
|
||||||
|
|
||||||
type skiplist struct {
|
type skiplist struct {
|
||||||
header *Node
|
header *Node
|
||||||
tail *Node
|
tail *Node
|
||||||
length int64
|
length int64
|
||||||
level int16
|
level int16
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeNode(level int16, score float64, member string)*Node {
|
func makeNode(level int16, score float64, member string) *Node {
|
||||||
n := &Node{
|
n := &Node{
|
||||||
Element: Element{
|
Element: Element{
|
||||||
Score: score,
|
Score: score,
|
||||||
Member: member,
|
Member: member,
|
||||||
},
|
},
|
||||||
level: make([]*Level, level),
|
level: make([]*Level, level),
|
||||||
}
|
}
|
||||||
for i := range n.level {
|
for i := range n.level {
|
||||||
n.level[i] = new(Level)
|
n.level[i] = new(Level)
|
||||||
}
|
}
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeSkiplist()*skiplist {
|
func makeSkiplist() *skiplist {
|
||||||
return &skiplist{
|
return &skiplist{
|
||||||
level: 1,
|
level: 1,
|
||||||
header: makeNode(maxLevel, 0, ""),
|
header: makeNode(maxLevel, 0, ""),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func randomLevel() int16 {
|
func randomLevel() int16 {
|
||||||
level := int16(1)
|
level := int16(1)
|
||||||
for float32(rand.Int31()&0xFFFF) < (0.25 * 0xFFFF) {
|
for float32(rand.Int31()&0xFFFF) < (0.25 * 0xFFFF) {
|
||||||
level++
|
level++
|
||||||
}
|
}
|
||||||
if level < maxLevel {
|
if level < maxLevel {
|
||||||
return level
|
return level
|
||||||
}
|
}
|
||||||
return maxLevel
|
return maxLevel
|
||||||
}
|
}
|
||||||
|
|
||||||
func (skiplist *skiplist)insert(member string, score float64)*Node {
|
func (skiplist *skiplist) insert(member string, score float64) *Node {
|
||||||
update := make([]*Node, maxLevel) // link new node with node in `update`
|
update := make([]*Node, maxLevel) // link new node with node in `update`
|
||||||
rank := make([]int64, maxLevel)
|
rank := make([]int64, maxLevel)
|
||||||
|
|
||||||
// find position to insert
|
// find position to insert
|
||||||
node := skiplist.header
|
node := skiplist.header
|
||||||
for i := skiplist.level - 1; i >= 0; i-- {
|
for i := skiplist.level - 1; i >= 0; i-- {
|
||||||
if i == skiplist.level - 1 {
|
if i == skiplist.level-1 {
|
||||||
rank[i] = 0
|
rank[i] = 0
|
||||||
} else {
|
} else {
|
||||||
rank[i] = rank[i + 1] // store rank that is crossed to reach the insert position
|
rank[i] = rank[i+1] // store rank that is crossed to reach the insert position
|
||||||
}
|
}
|
||||||
if node.level[i] != nil {
|
if node.level[i] != nil {
|
||||||
// traverse the skip list
|
// traverse the skip list
|
||||||
for node.level[i].forward != nil &&
|
for node.level[i].forward != nil &&
|
||||||
(node.level[i].forward.Score < score ||
|
(node.level[i].forward.Score < score ||
|
||||||
(node.level[i].forward.Score == score && node.level[i].forward.Member < member)) { // same score, different key
|
(node.level[i].forward.Score == score && node.level[i].forward.Member < member)) { // same score, different key
|
||||||
rank[i] += node.level[i].span
|
rank[i] += node.level[i].span
|
||||||
node = node.level[i].forward
|
node = node.level[i].forward
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
update[i] = node
|
update[i] = node
|
||||||
}
|
}
|
||||||
|
|
||||||
level := randomLevel()
|
level := randomLevel()
|
||||||
// extend skiplist level
|
// extend skiplist level
|
||||||
if level > skiplist.level {
|
if level > skiplist.level {
|
||||||
for i := skiplist.level; i < level; i++ {
|
for i := skiplist.level; i < level; i++ {
|
||||||
rank[i] = 0
|
rank[i] = 0
|
||||||
update[i] = skiplist.header
|
update[i] = skiplist.header
|
||||||
update[i].level[i].span = skiplist.length
|
update[i].level[i].span = skiplist.length
|
||||||
}
|
}
|
||||||
skiplist.level = level
|
skiplist.level = level
|
||||||
}
|
}
|
||||||
|
|
||||||
// make node and link into skiplist
|
// make node and link into skiplist
|
||||||
node = makeNode(level, score, member)
|
node = makeNode(level, score, member)
|
||||||
for i := int16(0); i < level; i++ {
|
for i := int16(0); i < level; i++ {
|
||||||
node.level[i].forward = update[i].level[i].forward
|
node.level[i].forward = update[i].level[i].forward
|
||||||
update[i].level[i].forward = node
|
update[i].level[i].forward = node
|
||||||
|
|
||||||
// update span covered by update[i] as node is inserted here
|
// update span covered by update[i] as node is inserted here
|
||||||
node.level[i].span = update[i].level[i].span - (rank[0] - rank[i])
|
node.level[i].span = update[i].level[i].span - (rank[0] - rank[i])
|
||||||
update[i].level[i].span = (rank[0] - rank[i]) + 1
|
update[i].level[i].span = (rank[0] - rank[i]) + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// increment span for untouched levels
|
// increment span for untouched levels
|
||||||
for i := level; i < skiplist.level; i++ {
|
for i := level; i < skiplist.level; i++ {
|
||||||
update[i].level[i].span++
|
update[i].level[i].span++
|
||||||
}
|
}
|
||||||
|
|
||||||
// set backward node
|
// set backward node
|
||||||
if update[0] == skiplist.header {
|
if update[0] == skiplist.header {
|
||||||
node.backward = nil
|
node.backward = nil
|
||||||
} else {
|
} else {
|
||||||
node.backward = update[0]
|
node.backward = update[0]
|
||||||
}
|
}
|
||||||
if node.level[0].forward != nil {
|
if node.level[0].forward != nil {
|
||||||
node.level[0].forward.backward = node
|
node.level[0].forward.backward = node
|
||||||
} else {
|
} else {
|
||||||
skiplist.tail = node
|
skiplist.tail = node
|
||||||
}
|
}
|
||||||
skiplist.length++
|
skiplist.length++
|
||||||
return node
|
return node
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@@ -134,212 +133,212 @@ func (skiplist *skiplist)insert(member string, score float64)*Node {
|
|||||||
* param update: backward node (of target)
|
* param update: backward node (of target)
|
||||||
*/
|
*/
|
||||||
func (skiplist *skiplist) removeNode(node *Node, update []*Node) {
|
func (skiplist *skiplist) removeNode(node *Node, update []*Node) {
|
||||||
for i := int16(0); i < skiplist.level; i++ {
|
for i := int16(0); i < skiplist.level; i++ {
|
||||||
if update[i].level[i].forward == node {
|
if update[i].level[i].forward == node {
|
||||||
update[i].level[i].span += node.level[i].span - 1
|
update[i].level[i].span += node.level[i].span - 1
|
||||||
update[i].level[i].forward = node.level[i].forward
|
update[i].level[i].forward = node.level[i].forward
|
||||||
} else {
|
} else {
|
||||||
update[i].level[i].span--
|
update[i].level[i].span--
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if node.level[0].forward != nil {
|
if node.level[0].forward != nil {
|
||||||
node.level[0].forward.backward = node.backward
|
node.level[0].forward.backward = node.backward
|
||||||
} else {
|
} else {
|
||||||
skiplist.tail = node.backward
|
skiplist.tail = node.backward
|
||||||
}
|
}
|
||||||
for skiplist.level > 1 && skiplist.header.level[skiplist.level-1].forward == nil {
|
for skiplist.level > 1 && skiplist.header.level[skiplist.level-1].forward == nil {
|
||||||
skiplist.level--
|
skiplist.level--
|
||||||
}
|
}
|
||||||
skiplist.length--
|
skiplist.length--
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* return: has found and removed node
|
* return: has found and removed node
|
||||||
*/
|
*/
|
||||||
func (skiplist *skiplist) remove(member string, score float64)bool {
|
func (skiplist *skiplist) remove(member string, score float64) bool {
|
||||||
/*
|
/*
|
||||||
* find backward node (of target) or last node of each level
|
* find backward node (of target) or last node of each level
|
||||||
* their forward need to be updated
|
* their forward need to be updated
|
||||||
*/
|
*/
|
||||||
update := make([]*Node, maxLevel)
|
update := make([]*Node, maxLevel)
|
||||||
node := skiplist.header
|
node := skiplist.header
|
||||||
for i := skiplist.level - 1; i >= 0; i-- {
|
for i := skiplist.level - 1; i >= 0; i-- {
|
||||||
for node.level[i].forward != nil &&
|
for node.level[i].forward != nil &&
|
||||||
(node.level[i].forward.Score < score ||
|
(node.level[i].forward.Score < score ||
|
||||||
(node.level[i].forward.Score == score &&
|
(node.level[i].forward.Score == score &&
|
||||||
node.level[i].forward.Member < member)) {
|
node.level[i].forward.Member < member)) {
|
||||||
node = node.level[i].forward
|
node = node.level[i].forward
|
||||||
}
|
}
|
||||||
update[i] = node
|
update[i] = node
|
||||||
}
|
}
|
||||||
node = node.level[0].forward
|
node = node.level[0].forward
|
||||||
if node != nil && score == node.Score && node.Member == member {
|
if node != nil && score == node.Score && node.Member == member {
|
||||||
skiplist.removeNode(node, update)
|
skiplist.removeNode(node, update)
|
||||||
// free x
|
// free x
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* return: 1 based rank, 0 means member not found
|
* return: 1 based rank, 0 means member not found
|
||||||
*/
|
*/
|
||||||
func (skiplist *skiplist) getRank(member string, score float64)int64 {
|
func (skiplist *skiplist) getRank(member string, score float64) int64 {
|
||||||
var rank int64 = 0
|
var rank int64 = 0
|
||||||
x := skiplist.header
|
x := skiplist.header
|
||||||
for i := skiplist.level - 1; i >= 0; i-- {
|
for i := skiplist.level - 1; i >= 0; i-- {
|
||||||
for x.level[i].forward != nil &&
|
for x.level[i].forward != nil &&
|
||||||
(x.level[i].forward.Score < score ||
|
(x.level[i].forward.Score < score ||
|
||||||
(x.level[i].forward.Score == score &&
|
(x.level[i].forward.Score == score &&
|
||||||
x.level[i].forward.Member <= member)) {
|
x.level[i].forward.Member <= member)) {
|
||||||
rank += x.level[i].span
|
rank += x.level[i].span
|
||||||
x = x.level[i].forward
|
x = x.level[i].forward
|
||||||
}
|
}
|
||||||
|
|
||||||
/* x might be equal to zsl->header, so test if obj is non-NULL */
|
/* x might be equal to zsl->header, so test if obj is non-NULL */
|
||||||
if x.Member == member {
|
if x.Member == member {
|
||||||
return rank
|
return rank
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* 1-based rank
|
* 1-based rank
|
||||||
*/
|
*/
|
||||||
func (skiplist *skiplist) getByRank(rank int64)*Node {
|
func (skiplist *skiplist) getByRank(rank int64) *Node {
|
||||||
var i int64 = 0
|
var i int64 = 0
|
||||||
n := skiplist.header
|
n := skiplist.header
|
||||||
// scan from top level
|
// scan from top level
|
||||||
for level := skiplist.level - 1; level >= 0; level-- {
|
for level := skiplist.level - 1; level >= 0; level-- {
|
||||||
for n.level[level].forward != nil && (i+n.level[level].span) <= rank {
|
for n.level[level].forward != nil && (i+n.level[level].span) <= rank {
|
||||||
i += n.level[level].span
|
i += n.level[level].span
|
||||||
n = n.level[level].forward
|
n = n.level[level].forward
|
||||||
}
|
}
|
||||||
if i == rank {
|
if i == rank {
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (skiplist *skiplist) hasInRange(min *ScoreBorder, max *ScoreBorder) bool {
|
func (skiplist *skiplist) hasInRange(min *ScoreBorder, max *ScoreBorder) bool {
|
||||||
// min & max = empty
|
// min & max = empty
|
||||||
if min.Value > max.Value || (min.Value == max.Value && (min.Exclude || max.Exclude)) {
|
if min.Value > max.Value || (min.Value == max.Value && (min.Exclude || max.Exclude)) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
// min > tail
|
// min > tail
|
||||||
n := skiplist.tail
|
n := skiplist.tail
|
||||||
if n == nil || !min.less(n.Score) {
|
if n == nil || !min.less(n.Score) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
// max < head
|
// max < head
|
||||||
n = skiplist.header.level[0].forward
|
n = skiplist.header.level[0].forward
|
||||||
if n == nil || !max.greater(n.Score) {
|
if n == nil || !max.greater(n.Score) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (skiplist *skiplist) getFirstInScoreRange(min *ScoreBorder, max *ScoreBorder) *Node {
|
func (skiplist *skiplist) getFirstInScoreRange(min *ScoreBorder, max *ScoreBorder) *Node {
|
||||||
if !skiplist.hasInRange(min, max) {
|
if !skiplist.hasInRange(min, max) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
n := skiplist.header
|
n := skiplist.header
|
||||||
// scan from top level
|
// scan from top level
|
||||||
for level := skiplist.level - 1; level >= 0; level-- {
|
for level := skiplist.level - 1; level >= 0; level-- {
|
||||||
// if forward is not in range than move forward
|
// if forward is not in range than move forward
|
||||||
for n.level[level].forward != nil && !min.less(n.level[level].forward.Score) {
|
for n.level[level].forward != nil && !min.less(n.level[level].forward.Score) {
|
||||||
n = n.level[level].forward
|
n = n.level[level].forward
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
/* This is an inner range, so the next node cannot be NULL. */
|
/* This is an inner range, so the next node cannot be NULL. */
|
||||||
n = n.level[0].forward
|
n = n.level[0].forward
|
||||||
if !max.greater(n.Score) {
|
if !max.greater(n.Score) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
func (skiplist *skiplist) getLastInScoreRange(min *ScoreBorder, max *ScoreBorder) *Node {
|
func (skiplist *skiplist) getLastInScoreRange(min *ScoreBorder, max *ScoreBorder) *Node {
|
||||||
if !skiplist.hasInRange(min, max) {
|
if !skiplist.hasInRange(min, max) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
n := skiplist.header
|
n := skiplist.header
|
||||||
// scan from top level
|
// scan from top level
|
||||||
for level := skiplist.level - 1; level >= 0; level-- {
|
for level := skiplist.level - 1; level >= 0; level-- {
|
||||||
for n.level[level].forward != nil && max.greater(n.level[level].forward.Score) {
|
for n.level[level].forward != nil && max.greater(n.level[level].forward.Score) {
|
||||||
n = n.level[level].forward
|
n = n.level[level].forward
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !min.less(n.Score) {
|
if !min.less(n.Score) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* return removed elements
|
* return removed elements
|
||||||
*/
|
*/
|
||||||
func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder)(removed []*Element) {
|
func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder) (removed []*Element) {
|
||||||
update := make([]*Node, maxLevel)
|
update := make([]*Node, maxLevel)
|
||||||
removed = make([]*Element, 0)
|
removed = make([]*Element, 0)
|
||||||
// find backward nodes (of target range) or last node of each level
|
// find backward nodes (of target range) or last node of each level
|
||||||
node := skiplist.header
|
node := skiplist.header
|
||||||
for i := skiplist.level - 1; i >= 0; i-- {
|
for i := skiplist.level - 1; i >= 0; i-- {
|
||||||
for node.level[i].forward != nil {
|
for node.level[i].forward != nil {
|
||||||
if min.less(node.level[i].forward.Score) { // already in range
|
if min.less(node.level[i].forward.Score) { // already in range
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
node = node.level[i].forward
|
node = node.level[i].forward
|
||||||
}
|
}
|
||||||
update[i] = node
|
update[i] = node
|
||||||
}
|
}
|
||||||
|
|
||||||
// node is the first one within range
|
// node is the first one within range
|
||||||
node = node.level[0].forward
|
node = node.level[0].forward
|
||||||
|
|
||||||
// remove nodes in range
|
// remove nodes in range
|
||||||
for node != nil {
|
for node != nil {
|
||||||
if !max.greater(node.Score) { // already out of range
|
if !max.greater(node.Score) { // already out of range
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
next := node.level[0].forward
|
next := node.level[0].forward
|
||||||
removedElement := node.Element
|
removedElement := node.Element
|
||||||
removed = append(removed, &removedElement)
|
removed = append(removed, &removedElement)
|
||||||
skiplist.removeNode(node, update)
|
skiplist.removeNode(node, update)
|
||||||
node = next
|
node = next
|
||||||
}
|
}
|
||||||
return removed
|
return removed
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1-based rank, including start, exclude stop
|
// 1-based rank, including start, exclude stop
|
||||||
func (skiplist *skiplist) RemoveRangeByRank(start int64, stop int64)(removed []*Element) {
|
func (skiplist *skiplist) RemoveRangeByRank(start int64, stop int64) (removed []*Element) {
|
||||||
var i int64 = 0 // rank of iterator
|
var i int64 = 0 // rank of iterator
|
||||||
update := make([]*Node, maxLevel)
|
update := make([]*Node, maxLevel)
|
||||||
removed = make([]*Element, 0)
|
removed = make([]*Element, 0)
|
||||||
|
|
||||||
// scan from top level
|
// scan from top level
|
||||||
node := skiplist.header
|
node := skiplist.header
|
||||||
for level := skiplist.level - 1; level >= 0; level-- {
|
for level := skiplist.level - 1; level >= 0; level-- {
|
||||||
for node.level[level].forward != nil && (i+node.level[level].span) < start {
|
for node.level[level].forward != nil && (i+node.level[level].span) < start {
|
||||||
i += node.level[level].span
|
i += node.level[level].span
|
||||||
node = node.level[level].forward
|
node = node.level[level].forward
|
||||||
}
|
}
|
||||||
update[level] = node
|
update[level] = node
|
||||||
}
|
}
|
||||||
|
|
||||||
i++
|
i++
|
||||||
node = node.level[0].forward // first node in range
|
node = node.level[0].forward // first node in range
|
||||||
|
|
||||||
// remove nodes in range
|
// remove nodes in range
|
||||||
for node != nil && i < stop {
|
for node != nil && i < stop {
|
||||||
next := node.level[0].forward
|
next := node.level[0].forward
|
||||||
removedElement := node.Element
|
removedElement := node.Element
|
||||||
removed = append(removed, &removedElement)
|
removed = append(removed, &removedElement)
|
||||||
skiplist.removeNode(node, update)
|
skiplist.removeNode(node, update)
|
||||||
node = next
|
node = next
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
return removed
|
return removed
|
||||||
}
|
}
|
||||||
|
@@ -1,227 +1,226 @@
|
|||||||
package sortedset
|
package sortedset
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SortedSet struct {
|
type SortedSet struct {
|
||||||
dict map[string]*Element
|
dict map[string]*Element
|
||||||
skiplist *skiplist
|
skiplist *skiplist
|
||||||
}
|
}
|
||||||
|
|
||||||
func Make()*SortedSet {
|
func Make() *SortedSet {
|
||||||
return &SortedSet{
|
return &SortedSet{
|
||||||
dict: make(map[string]*Element),
|
dict: make(map[string]*Element),
|
||||||
skiplist: makeSkiplist(),
|
skiplist: makeSkiplist(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* return: has inserted new node
|
* return: has inserted new node
|
||||||
*/
|
*/
|
||||||
func (sortedSet *SortedSet)Add(member string, score float64)bool {
|
func (sortedSet *SortedSet) Add(member string, score float64) bool {
|
||||||
element, ok := sortedSet.dict[member]
|
element, ok := sortedSet.dict[member]
|
||||||
sortedSet.dict[member] = &Element{
|
sortedSet.dict[member] = &Element{
|
||||||
Member: member,
|
Member: member,
|
||||||
Score: score,
|
Score: score,
|
||||||
}
|
}
|
||||||
if ok {
|
if ok {
|
||||||
if score != element.Score {
|
if score != element.Score {
|
||||||
sortedSet.skiplist.remove(member, score)
|
sortedSet.skiplist.remove(member, score)
|
||||||
sortedSet.skiplist.insert(member, score)
|
sortedSet.skiplist.insert(member, score)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
} else {
|
} else {
|
||||||
sortedSet.skiplist.insert(member, score)
|
sortedSet.skiplist.insert(member, score)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sortedSet *SortedSet) Len()int64 {
|
func (sortedSet *SortedSet) Len() int64 {
|
||||||
return int64(len(sortedSet.dict))
|
return int64(len(sortedSet.dict))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sortedSet *SortedSet) Get(member string) (element *Element, ok bool) {
|
func (sortedSet *SortedSet) Get(member string) (element *Element, ok bool) {
|
||||||
element, ok = sortedSet.dict[member]
|
element, ok = sortedSet.dict[member]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
return element, true
|
return element, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sortedSet *SortedSet) Remove(member string)bool {
|
func (sortedSet *SortedSet) Remove(member string) bool {
|
||||||
v, ok := sortedSet.dict[member]
|
v, ok := sortedSet.dict[member]
|
||||||
if ok {
|
if ok {
|
||||||
sortedSet.skiplist.remove(member, v.Score)
|
sortedSet.skiplist.remove(member, v.Score)
|
||||||
delete(sortedSet.dict, member)
|
delete(sortedSet.dict, member)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* get 0-based rank
|
* get 0-based rank
|
||||||
*/
|
*/
|
||||||
func (sortedSet *SortedSet) GetRank(member string, desc bool) (rank int64) {
|
func (sortedSet *SortedSet) GetRank(member string, desc bool) (rank int64) {
|
||||||
element, ok := sortedSet.dict[member]
|
element, ok := sortedSet.dict[member]
|
||||||
if !ok {
|
if !ok {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
r := sortedSet.skiplist.getRank(member, element.Score)
|
r := sortedSet.skiplist.getRank(member, element.Score)
|
||||||
if desc {
|
if desc {
|
||||||
r = sortedSet.skiplist.length - r
|
r = sortedSet.skiplist.length - r
|
||||||
} else {
|
} else {
|
||||||
r--
|
r--
|
||||||
}
|
}
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* traverse [start, stop), 0-based rank
|
* traverse [start, stop), 0-based rank
|
||||||
*/
|
*/
|
||||||
func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer func(element *Element)bool) {
|
func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer func(element *Element) bool) {
|
||||||
size := int64(sortedSet.Len())
|
size := int64(sortedSet.Len())
|
||||||
if start < 0 || start >= size {
|
if start < 0 || start >= size {
|
||||||
panic("illegal start " + strconv.FormatInt(start, 10))
|
panic("illegal start " + strconv.FormatInt(start, 10))
|
||||||
}
|
}
|
||||||
if stop < start || stop > size {
|
if stop < start || stop > size {
|
||||||
panic("illegal end " + strconv.FormatInt(stop, 10))
|
panic("illegal end " + strconv.FormatInt(stop, 10))
|
||||||
}
|
}
|
||||||
|
|
||||||
// find start node
|
// find start node
|
||||||
var node *Node
|
var node *Node
|
||||||
if desc {
|
if desc {
|
||||||
node = sortedSet.skiplist.tail
|
node = sortedSet.skiplist.tail
|
||||||
if start > 0 {
|
if start > 0 {
|
||||||
node = sortedSet.skiplist.getByRank(int64(size - start))
|
node = sortedSet.skiplist.getByRank(int64(size - start))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
node = sortedSet.skiplist.header.level[0].forward
|
node = sortedSet.skiplist.header.level[0].forward
|
||||||
if start > 0 {
|
if start > 0 {
|
||||||
node = sortedSet.skiplist.getByRank(int64(start + 1))
|
node = sortedSet.skiplist.getByRank(int64(start + 1))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sliceSize := int(stop - start)
|
sliceSize := int(stop - start)
|
||||||
for i := 0; i < sliceSize; i++ {
|
for i := 0; i < sliceSize; i++ {
|
||||||
if !consumer(&node.Element) {
|
if !consumer(&node.Element) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if desc {
|
if desc {
|
||||||
node = node.backward
|
node = node.backward
|
||||||
} else {
|
} else {
|
||||||
node = node.level[0].forward
|
node = node.level[0].forward
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* return [start, stop), 0-based rank
|
* return [start, stop), 0-based rank
|
||||||
* assert start in [0, size), stop in [start, size]
|
* assert start in [0, size), stop in [start, size]
|
||||||
*/
|
*/
|
||||||
func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool)[]*Element {
|
func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool) []*Element {
|
||||||
sliceSize := int(stop - start)
|
sliceSize := int(stop - start)
|
||||||
slice := make([]*Element, sliceSize)
|
slice := make([]*Element, sliceSize)
|
||||||
i := 0
|
i := 0
|
||||||
sortedSet.ForEach(start, stop, desc, func(element *Element)bool {
|
sortedSet.ForEach(start, stop, desc, func(element *Element) bool {
|
||||||
slice[i] = element
|
slice[i] = element
|
||||||
i++
|
i++
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return slice
|
return slice
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sortedSet *SortedSet) Count(min *ScoreBorder, max *ScoreBorder)int64 {
|
func (sortedSet *SortedSet) Count(min *ScoreBorder, max *ScoreBorder) int64 {
|
||||||
var i int64 = 0
|
var i int64 = 0
|
||||||
// ascending order
|
// ascending order
|
||||||
sortedSet.ForEach(0, sortedSet.Len(), false, func(element *Element) bool {
|
sortedSet.ForEach(0, sortedSet.Len(), false, func(element *Element) bool {
|
||||||
gtMin := min.less(element.Score) // greater than min
|
gtMin := min.less(element.Score) // greater than min
|
||||||
if !gtMin {
|
if !gtMin {
|
||||||
// has not into range, continue foreach
|
// has not into range, continue foreach
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
ltMax := max.greater(element.Score) // less than max
|
ltMax := max.greater(element.Score) // less than max
|
||||||
if !ltMax {
|
if !ltMax {
|
||||||
// break through score border, break foreach
|
// break through score border, break foreach
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
// gtMin && ltMax
|
// gtMin && ltMax
|
||||||
i++
|
i++
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return i
|
return i
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sortedSet *SortedSet) ForEachByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool, consumer func(element *Element) bool) {
|
func (sortedSet *SortedSet) ForEachByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool, consumer func(element *Element) bool) {
|
||||||
// find start node
|
// find start node
|
||||||
var node *Node
|
var node *Node
|
||||||
if desc {
|
if desc {
|
||||||
node = sortedSet.skiplist.getLastInScoreRange(min, max)
|
node = sortedSet.skiplist.getLastInScoreRange(min, max)
|
||||||
} else {
|
} else {
|
||||||
node = sortedSet.skiplist.getFirstInScoreRange(min, max)
|
node = sortedSet.skiplist.getFirstInScoreRange(min, max)
|
||||||
}
|
}
|
||||||
|
|
||||||
for node != nil && offset > 0 {
|
for node != nil && offset > 0 {
|
||||||
if desc {
|
if desc {
|
||||||
node = node.backward
|
node = node.backward
|
||||||
} else {
|
} else {
|
||||||
node = node.level[0].forward
|
node = node.level[0].forward
|
||||||
}
|
}
|
||||||
offset--
|
offset--
|
||||||
}
|
}
|
||||||
|
|
||||||
// A negative limit returns all elements from the offset
|
// A negative limit returns all elements from the offset
|
||||||
for i := 0; (i < int(limit) || limit < 0) && node != nil; i++ {
|
for i := 0; (i < int(limit) || limit < 0) && node != nil; i++ {
|
||||||
if !consumer(&node.Element) {
|
if !consumer(&node.Element) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if desc {
|
if desc {
|
||||||
node = node.backward
|
node = node.backward
|
||||||
} else {
|
} else {
|
||||||
node = node.level[0].forward
|
node = node.level[0].forward
|
||||||
}
|
}
|
||||||
if node == nil {
|
if node == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
gtMin := min.less(node.Element.Score) // greater than min
|
gtMin := min.less(node.Element.Score) // greater than min
|
||||||
ltMax := max.greater(node.Element.Score)
|
ltMax := max.greater(node.Element.Score)
|
||||||
if !gtMin || !ltMax {
|
if !gtMin || !ltMax {
|
||||||
break // break through score border
|
break // break through score border
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* param limit: <0 means no limit
|
* param limit: <0 means no limit
|
||||||
*/
|
*/
|
||||||
func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool)[]*Element {
|
func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool) []*Element {
|
||||||
if limit == 0 || offset < 0{
|
if limit == 0 || offset < 0 {
|
||||||
return make([]*Element, 0)
|
return make([]*Element, 0)
|
||||||
}
|
}
|
||||||
slice := make([]*Element, 0)
|
slice := make([]*Element, 0)
|
||||||
sortedSet.ForEachByScore(min, max, offset, limit, desc, func(element *Element) bool {
|
sortedSet.ForEachByScore(min, max, offset, limit, desc, func(element *Element) bool {
|
||||||
slice = append(slice, element)
|
slice = append(slice, element)
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return slice
|
return slice
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder)int64 {
|
func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder) int64 {
|
||||||
removed := sortedSet.skiplist.RemoveRangeByScore(min, max)
|
removed := sortedSet.skiplist.RemoveRangeByScore(min, max)
|
||||||
for _, element := range removed {
|
for _, element := range removed {
|
||||||
delete(sortedSet.dict, element.Member)
|
delete(sortedSet.dict, element.Member)
|
||||||
}
|
}
|
||||||
return int64(len(removed))
|
return int64(len(removed))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* 0-based rank, [start, stop)
|
* 0-based rank, [start, stop)
|
||||||
*/
|
*/
|
||||||
func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64)int64 {
|
func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64) int64 {
|
||||||
removed := sortedSet.skiplist.RemoveRangeByRank(start + 1, stop + 1)
|
removed := sortedSet.skiplist.RemoveRangeByRank(start+1, stop+1)
|
||||||
for _, element := range removed {
|
for _, element := range removed {
|
||||||
delete(sortedSet.dict, element.Member)
|
delete(sortedSet.dict, element.Member)
|
||||||
}
|
}
|
||||||
return int64(len(removed))
|
return int64(len(removed))
|
||||||
}
|
}
|
||||||
|
@@ -1,28 +1,28 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
func Equals(a interface{}, b interface{})bool {
|
func Equals(a interface{}, b interface{}) bool {
|
||||||
sliceA, okA := a.([]byte)
|
sliceA, okA := a.([]byte)
|
||||||
sliceB, okB := b.([]byte)
|
sliceB, okB := b.([]byte)
|
||||||
if okA && okB {
|
if okA && okB {
|
||||||
return BytesEquals(sliceA, sliceB)
|
return BytesEquals(sliceA, sliceB)
|
||||||
}
|
}
|
||||||
return a == b
|
return a == b
|
||||||
}
|
}
|
||||||
|
|
||||||
func BytesEquals(a []byte, b []byte) bool {
|
func BytesEquals(a []byte, b []byte) bool {
|
||||||
if (a == nil && b != nil) || (a != nil && b == nil) {
|
if (a == nil && b != nil) || (a != nil && b == nil) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if len(a) != len(b) {
|
if len(a) != len(b) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
size := len(a)
|
size := len(a)
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
av := a[i]
|
av := a[i]
|
||||||
bv := b[i]
|
bv := b[i]
|
||||||
if av != bv {
|
if av != bv {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@@ -2,7 +2,7 @@ package db
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"github.com/HDT3213/godis/src/config"
|
"github.com/HDT3213/godis/src/config"
|
||||||
"github.com/HDT3213/godis/src/datastruct/dict"
|
"github.com/HDT3213/godis/src/datastruct/dict"
|
||||||
List "github.com/HDT3213/godis/src/datastruct/list"
|
List "github.com/HDT3213/godis/src/datastruct/list"
|
||||||
"github.com/HDT3213/godis/src/datastruct/lock"
|
"github.com/HDT3213/godis/src/datastruct/lock"
|
||||||
@@ -29,18 +29,18 @@ func makeExpireCmd(key string, expireAt time.Time) *reply.MultiBulkReply {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func makeAofCmd(cmd string, args [][]byte) *reply.MultiBulkReply {
|
func makeAofCmd(cmd string, args [][]byte) *reply.MultiBulkReply {
|
||||||
params := make([][]byte, len(args)+1)
|
params := make([][]byte, len(args)+1)
|
||||||
copy(params[1:], args)
|
copy(params[1:], args)
|
||||||
params[0] = []byte(cmd)
|
params[0] = []byte(cmd)
|
||||||
return reply.MakeMultiBulkReply(params)
|
return reply.MakeMultiBulkReply(params)
|
||||||
}
|
}
|
||||||
|
|
||||||
// send command to aof
|
// send command to aof
|
||||||
func (db *DB) AddAof(args *reply.MultiBulkReply) {
|
func (db *DB) AddAof(args *reply.MultiBulkReply) {
|
||||||
// aofChan == nil when loadAof
|
// aofChan == nil when loadAof
|
||||||
if config.Properties.AppendOnly && db.aofChan != nil {
|
if config.Properties.AppendOnly && db.aofChan != nil {
|
||||||
db.aofChan <- args
|
db.aofChan <- args
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// listen aof channel and write into file
|
// listen aof channel and write into file
|
||||||
@@ -72,12 +72,12 @@ func trim(msg []byte) string {
|
|||||||
|
|
||||||
// read aof file
|
// read aof file
|
||||||
func (db *DB) loadAof(maxBytes int) {
|
func (db *DB) loadAof(maxBytes int) {
|
||||||
// delete aofChan to prevent write again
|
// delete aofChan to prevent write again
|
||||||
aofChan := db.aofChan
|
aofChan := db.aofChan
|
||||||
db.aofChan = nil
|
db.aofChan = nil
|
||||||
defer func(aofChan chan *reply.MultiBulkReply) {
|
defer func(aofChan chan *reply.MultiBulkReply) {
|
||||||
db.aofChan = aofChan
|
db.aofChan = aofChan
|
||||||
}(aofChan)
|
}(aofChan)
|
||||||
|
|
||||||
file, err := os.Open(db.aofFilename)
|
file, err := os.Open(db.aofFilename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
366
src/db/db.go
366
src/db/db.go
@@ -1,297 +1,297 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/HDT3213/godis/src/config"
|
"github.com/HDT3213/godis/src/config"
|
||||||
"github.com/HDT3213/godis/src/datastruct/dict"
|
"github.com/HDT3213/godis/src/datastruct/dict"
|
||||||
List "github.com/HDT3213/godis/src/datastruct/list"
|
List "github.com/HDT3213/godis/src/datastruct/list"
|
||||||
"github.com/HDT3213/godis/src/datastruct/lock"
|
"github.com/HDT3213/godis/src/datastruct/lock"
|
||||||
"github.com/HDT3213/godis/src/interface/redis"
|
"github.com/HDT3213/godis/src/interface/redis"
|
||||||
"github.com/HDT3213/godis/src/lib/logger"
|
"github.com/HDT3213/godis/src/lib/logger"
|
||||||
"github.com/HDT3213/godis/src/pubsub"
|
"github.com/HDT3213/godis/src/pubsub"
|
||||||
"github.com/HDT3213/godis/src/redis/reply"
|
"github.com/HDT3213/godis/src/redis/reply"
|
||||||
"os"
|
"os"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DataEntity struct {
|
type DataEntity struct {
|
||||||
Data interface{}
|
Data interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dataDictSize = 1 << 16
|
dataDictSize = 1 << 16
|
||||||
ttlDictSize = 1 << 10
|
ttlDictSize = 1 << 10
|
||||||
lockerSize = 128
|
lockerSize = 128
|
||||||
aofQueueSize = 1 << 16
|
aofQueueSize = 1 << 16
|
||||||
)
|
)
|
||||||
|
|
||||||
// args don't include cmd line
|
// args don't include cmd line
|
||||||
type CmdFunc func(db *DB, args [][]byte) redis.Reply
|
type CmdFunc func(db *DB, args [][]byte) redis.Reply
|
||||||
|
|
||||||
type DB struct {
|
type DB struct {
|
||||||
// key -> DataEntity
|
// key -> DataEntity
|
||||||
Data dict.Dict
|
Data dict.Dict
|
||||||
// key -> expireTime (time.Time)
|
// key -> expireTime (time.Time)
|
||||||
TTLMap dict.Dict
|
TTLMap dict.Dict
|
||||||
// channel -> list<*client>
|
// channel -> list<*client>
|
||||||
SubMap dict.Dict
|
SubMap dict.Dict
|
||||||
|
|
||||||
// dict will ensure thread safety of its method
|
// dict will ensure thread safety of its method
|
||||||
// use this mutex for complicated command only, eg. rpush, incr ...
|
// use this mutex for complicated command only, eg. rpush, incr ...
|
||||||
Locker *lock.Locks
|
Locker *lock.Locks
|
||||||
|
|
||||||
// TimerTask interval
|
// TimerTask interval
|
||||||
interval time.Duration
|
interval time.Duration
|
||||||
|
|
||||||
stopWorld sync.WaitGroup
|
stopWorld sync.WaitGroup
|
||||||
|
|
||||||
hub *pubsub.Hub
|
hub *pubsub.Hub
|
||||||
|
|
||||||
// main goroutine send commands to aof goroutine through aofChan
|
// main goroutine send commands to aof goroutine through aofChan
|
||||||
aofChan chan *reply.MultiBulkReply
|
aofChan chan *reply.MultiBulkReply
|
||||||
aofFile *os.File
|
aofFile *os.File
|
||||||
aofFilename string
|
aofFilename string
|
||||||
|
|
||||||
aofRewriteChan chan *reply.MultiBulkReply
|
aofRewriteChan chan *reply.MultiBulkReply
|
||||||
pausingAof sync.RWMutex
|
pausingAof sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
var router = MakeRouter()
|
var router = MakeRouter()
|
||||||
|
|
||||||
func MakeDB() *DB {
|
func MakeDB() *DB {
|
||||||
db := &DB{
|
db := &DB{
|
||||||
Data: dict.MakeConcurrent(dataDictSize),
|
Data: dict.MakeConcurrent(dataDictSize),
|
||||||
TTLMap: dict.MakeConcurrent(ttlDictSize),
|
TTLMap: dict.MakeConcurrent(ttlDictSize),
|
||||||
Locker: lock.Make(lockerSize),
|
Locker: lock.Make(lockerSize),
|
||||||
interval: 5 * time.Second,
|
interval: 5 * time.Second,
|
||||||
hub: pubsub.MakeHub(),
|
hub: pubsub.MakeHub(),
|
||||||
}
|
}
|
||||||
|
|
||||||
// aof
|
// aof
|
||||||
if config.Properties.AppendOnly {
|
if config.Properties.AppendOnly {
|
||||||
db.aofFilename = config.Properties.AppendFilename
|
db.aofFilename = config.Properties.AppendFilename
|
||||||
db.loadAof(0)
|
db.loadAof(0)
|
||||||
aofFile, err := os.OpenFile(db.aofFilename, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
|
aofFile, err := os.OpenFile(db.aofFilename, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0600)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn(err)
|
logger.Warn(err)
|
||||||
} else {
|
} else {
|
||||||
db.aofFile = aofFile
|
db.aofFile = aofFile
|
||||||
db.aofChan = make(chan *reply.MultiBulkReply, aofQueueSize)
|
db.aofChan = make(chan *reply.MultiBulkReply, aofQueueSize)
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
db.handleAof()
|
db.handleAof()
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// start timer
|
// start timer
|
||||||
db.TimerTask()
|
db.TimerTask()
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Close() {
|
func (db *DB) Close() {
|
||||||
if db.aofFile != nil {
|
if db.aofFile != nil {
|
||||||
err := db.aofFile.Close()
|
err := db.aofFile.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warn(err)
|
logger.Warn(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Exec(c redis.Connection, args [][]byte) (result redis.Reply) {
|
func (db *DB) Exec(c redis.Connection, args [][]byte) (result redis.Reply) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack())))
|
logger.Warn(fmt.Sprintf("error occurs: %v\n%s", err, string(debug.Stack())))
|
||||||
result = &reply.UnknownErrReply{}
|
result = &reply.UnknownErrReply{}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
cmd := strings.ToLower(string(args[0]))
|
cmd := strings.ToLower(string(args[0]))
|
||||||
|
|
||||||
// special commands
|
// special commands
|
||||||
if cmd == "subscribe" {
|
if cmd == "subscribe" {
|
||||||
if len(args) < 2 {
|
if len(args) < 2 {
|
||||||
return &reply.ArgNumErrReply{Cmd: "subscribe"}
|
return &reply.ArgNumErrReply{Cmd: "subscribe"}
|
||||||
}
|
}
|
||||||
return pubsub.Subscribe(db.hub, c, args[1:])
|
return pubsub.Subscribe(db.hub, c, args[1:])
|
||||||
} else if cmd == "publish" {
|
} else if cmd == "publish" {
|
||||||
return pubsub.Publish(db.hub, args[1:])
|
return pubsub.Publish(db.hub, args[1:])
|
||||||
} else if cmd == "unsubscribe" {
|
} else if cmd == "unsubscribe" {
|
||||||
return pubsub.UnSubscribe(db.hub, c, args[1:])
|
return pubsub.UnSubscribe(db.hub, c, args[1:])
|
||||||
} else if cmd == "bgrewriteaof" {
|
} else if cmd == "bgrewriteaof" {
|
||||||
// aof.go imports router.go, router.go cannot import BGRewriteAOF from aof.go
|
// aof.go imports router.go, router.go cannot import BGRewriteAOF from aof.go
|
||||||
reply := BGRewriteAOF(db, args[1:])
|
reply := BGRewriteAOF(db, args[1:])
|
||||||
return reply
|
return reply
|
||||||
}
|
}
|
||||||
|
|
||||||
// normal commands
|
// normal commands
|
||||||
cmdFunc, ok := router[cmd]
|
cmdFunc, ok := router[cmd]
|
||||||
if !ok {
|
if !ok {
|
||||||
return reply.MakeErrReply("ERR unknown command '" + cmd + "'")
|
return reply.MakeErrReply("ERR unknown command '" + cmd + "'")
|
||||||
}
|
}
|
||||||
if len(args) > 1 {
|
if len(args) > 1 {
|
||||||
result = cmdFunc(db, args[1:])
|
result = cmdFunc(db, args[1:])
|
||||||
} else {
|
} else {
|
||||||
result = cmdFunc(db, [][]byte{})
|
result = cmdFunc(db, [][]byte{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// aof
|
// aof
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ---- Data Access ----- */
|
/* ---- Data Access ----- */
|
||||||
|
|
||||||
func (db *DB) Get(key string) (*DataEntity, bool) {
|
func (db *DB) Get(key string) (*DataEntity, bool) {
|
||||||
db.stopWorld.Wait()
|
db.stopWorld.Wait()
|
||||||
|
|
||||||
raw, ok := db.Data.Get(key)
|
raw, ok := db.Data.Get(key)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
if db.IsExpired(key) {
|
if db.IsExpired(key) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
entity, _ := raw.(*DataEntity)
|
entity, _ := raw.(*DataEntity)
|
||||||
return entity, true
|
return entity, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Put(key string, entity *DataEntity) int {
|
func (db *DB) Put(key string, entity *DataEntity) int {
|
||||||
db.stopWorld.Wait()
|
db.stopWorld.Wait()
|
||||||
return db.Data.Put(key, entity)
|
return db.Data.Put(key, entity)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) PutIfExists(key string, entity *DataEntity) int {
|
func (db *DB) PutIfExists(key string, entity *DataEntity) int {
|
||||||
db.stopWorld.Wait()
|
db.stopWorld.Wait()
|
||||||
return db.Data.PutIfExists(key, entity)
|
return db.Data.PutIfExists(key, entity)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) PutIfAbsent(key string, entity *DataEntity) int {
|
func (db *DB) PutIfAbsent(key string, entity *DataEntity) int {
|
||||||
db.stopWorld.Wait()
|
db.stopWorld.Wait()
|
||||||
return db.Data.PutIfAbsent(key, entity)
|
return db.Data.PutIfAbsent(key, entity)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Remove(key string) {
|
func (db *DB) Remove(key string) {
|
||||||
db.stopWorld.Wait()
|
db.stopWorld.Wait()
|
||||||
db.Data.Remove(key)
|
db.Data.Remove(key)
|
||||||
db.TTLMap.Remove(key)
|
db.TTLMap.Remove(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Removes(keys ...string) (deleted int) {
|
func (db *DB) Removes(keys ...string) (deleted int) {
|
||||||
db.stopWorld.Wait()
|
db.stopWorld.Wait()
|
||||||
deleted = 0
|
deleted = 0
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
_, exists := db.Data.Get(key)
|
_, exists := db.Data.Get(key)
|
||||||
if exists {
|
if exists {
|
||||||
db.Data.Remove(key)
|
db.Data.Remove(key)
|
||||||
db.TTLMap.Remove(key)
|
db.TTLMap.Remove(key)
|
||||||
deleted++
|
deleted++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return deleted
|
return deleted
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Flush() {
|
func (db *DB) Flush() {
|
||||||
db.stopWorld.Add(1)
|
db.stopWorld.Add(1)
|
||||||
defer db.stopWorld.Done()
|
defer db.stopWorld.Done()
|
||||||
|
|
||||||
db.Data = dict.MakeConcurrent(dataDictSize)
|
db.Data = dict.MakeConcurrent(dataDictSize)
|
||||||
db.TTLMap = dict.MakeConcurrent(ttlDictSize)
|
db.TTLMap = dict.MakeConcurrent(ttlDictSize)
|
||||||
db.Locker = lock.Make(lockerSize)
|
db.Locker = lock.Make(lockerSize)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ---- Lock Function ----- */
|
/* ---- Lock Function ----- */
|
||||||
|
|
||||||
func (db *DB) Lock(key string) {
|
func (db *DB) Lock(key string) {
|
||||||
db.Locker.Lock(key)
|
db.Locker.Lock(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) RLock(key string) {
|
func (db *DB) RLock(key string) {
|
||||||
db.Locker.RLock(key)
|
db.Locker.RLock(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) UnLock(key string) {
|
func (db *DB) UnLock(key string) {
|
||||||
db.Locker.UnLock(key)
|
db.Locker.UnLock(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) RUnLock(key string) {
|
func (db *DB) RUnLock(key string) {
|
||||||
db.Locker.RUnLock(key)
|
db.Locker.RUnLock(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Locks(keys ...string) {
|
func (db *DB) Locks(keys ...string) {
|
||||||
db.Locker.Locks(keys...)
|
db.Locker.Locks(keys...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) RLocks(keys ...string) {
|
func (db *DB) RLocks(keys ...string) {
|
||||||
db.Locker.RLocks(keys...)
|
db.Locker.RLocks(keys...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) UnLocks(keys ...string) {
|
func (db *DB) UnLocks(keys ...string) {
|
||||||
db.Locker.UnLocks(keys...)
|
db.Locker.UnLocks(keys...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) RUnLocks(keys ...string) {
|
func (db *DB) RUnLocks(keys ...string) {
|
||||||
db.Locker.RUnLocks(keys...)
|
db.Locker.RUnLocks(keys...)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ---- TTL Functions ---- */
|
/* ---- TTL Functions ---- */
|
||||||
|
|
||||||
func (db *DB) Expire(key string, expireTime time.Time) {
|
func (db *DB) Expire(key string, expireTime time.Time) {
|
||||||
db.stopWorld.Wait()
|
db.stopWorld.Wait()
|
||||||
db.TTLMap.Put(key, expireTime)
|
db.TTLMap.Put(key, expireTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) Persist(key string) {
|
func (db *DB) Persist(key string) {
|
||||||
db.stopWorld.Wait()
|
db.stopWorld.Wait()
|
||||||
db.TTLMap.Remove(key)
|
db.TTLMap.Remove(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) IsExpired(key string) bool {
|
func (db *DB) IsExpired(key string) bool {
|
||||||
rawExpireTime, ok := db.TTLMap.Get(key)
|
rawExpireTime, ok := db.TTLMap.Get(key)
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
expireTime, _ := rawExpireTime.(time.Time)
|
expireTime, _ := rawExpireTime.(time.Time)
|
||||||
expired := time.Now().After(expireTime)
|
expired := time.Now().After(expireTime)
|
||||||
if expired {
|
if expired {
|
||||||
db.Remove(key)
|
db.Remove(key)
|
||||||
}
|
}
|
||||||
return expired
|
return expired
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) CleanExpired() {
|
func (db *DB) CleanExpired() {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
toRemove := &List.LinkedList{}
|
toRemove := &List.LinkedList{}
|
||||||
db.TTLMap.ForEach(func(key string, val interface{}) bool {
|
db.TTLMap.ForEach(func(key string, val interface{}) bool {
|
||||||
expireTime, _ := val.(time.Time)
|
expireTime, _ := val.(time.Time)
|
||||||
if now.After(expireTime) {
|
if now.After(expireTime) {
|
||||||
// expired
|
// expired
|
||||||
db.Data.Remove(key)
|
db.Data.Remove(key)
|
||||||
toRemove.Add(key)
|
toRemove.Add(key)
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
toRemove.ForEach(func(i int, val interface{}) bool {
|
toRemove.ForEach(func(i int, val interface{}) bool {
|
||||||
key, _ := val.(string)
|
key, _ := val.(string)
|
||||||
db.TTLMap.Remove(key)
|
db.TTLMap.Remove(key)
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) TimerTask() {
|
func (db *DB) TimerTask() {
|
||||||
ticker := time.NewTicker(db.interval)
|
ticker := time.NewTicker(db.interval)
|
||||||
go func() {
|
go func() {
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
db.CleanExpired()
|
db.CleanExpired()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ---- Subscribe Functions ---- */
|
/* ---- Subscribe Functions ---- */
|
||||||
|
|
||||||
func (db *DB) AfterClientClose(c redis.Connection) {
|
func (db *DB) AfterClientClose(c redis.Connection) {
|
||||||
pubsub.UnsubscribeAll(db.hub, c)
|
pubsub.UnsubscribeAll(db.hub, c)
|
||||||
}
|
}
|
||||||
|
674
src/db/hash.go
674
src/db/hash.go
@@ -1,422 +1,422 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
Dict "github.com/HDT3213/godis/src/datastruct/dict"
|
Dict "github.com/HDT3213/godis/src/datastruct/dict"
|
||||||
"github.com/HDT3213/godis/src/interface/redis"
|
"github.com/HDT3213/godis/src/interface/redis"
|
||||||
"github.com/HDT3213/godis/src/redis/reply"
|
"github.com/HDT3213/godis/src/redis/reply"
|
||||||
"github.com/shopspring/decimal"
|
"github.com/shopspring/decimal"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (db *DB) getAsDict(key string) (Dict.Dict, reply.ErrorReply) {
|
func (db *DB) getAsDict(key string) (Dict.Dict, reply.ErrorReply) {
|
||||||
entity, exists := db.Get(key)
|
entity, exists := db.Get(key)
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
dict, ok := entity.Data.(Dict.Dict)
|
dict, ok := entity.Data.(Dict.Dict)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, &reply.WrongTypeErrReply{}
|
return nil, &reply.WrongTypeErrReply{}
|
||||||
}
|
}
|
||||||
return dict, nil
|
return dict, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) getOrInitDict(key string) (dict Dict.Dict, inited bool, errReply reply.ErrorReply) {
|
func (db *DB) getOrInitDict(key string) (dict Dict.Dict, inited bool, errReply reply.ErrorReply) {
|
||||||
dict, errReply = db.getAsDict(key)
|
dict, errReply = db.getAsDict(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return nil, false, errReply
|
return nil, false, errReply
|
||||||
}
|
}
|
||||||
inited = false
|
inited = false
|
||||||
if dict == nil {
|
if dict == nil {
|
||||||
dict = Dict.MakeSimple()
|
dict = Dict.MakeSimple()
|
||||||
db.Put(key, &DataEntity{
|
db.Put(key, &DataEntity{
|
||||||
Data: dict,
|
Data: dict,
|
||||||
})
|
})
|
||||||
inited = true
|
inited = true
|
||||||
}
|
}
|
||||||
return dict, inited, nil
|
return dict, inited, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func HSet(db *DB, args [][]byte) redis.Reply {
|
func HSet(db *DB, args [][]byte) redis.Reply {
|
||||||
// parse args
|
// parse args
|
||||||
if len(args) != 3 {
|
if len(args) != 3 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hset' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'hset' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
field := string(args[1])
|
field := string(args[1])
|
||||||
value := args[2]
|
value := args[2]
|
||||||
|
|
||||||
// lock
|
// lock
|
||||||
db.Lock(key)
|
db.Lock(key)
|
||||||
defer db.UnLock(key)
|
defer db.UnLock(key)
|
||||||
|
|
||||||
// get or init entity
|
// get or init entity
|
||||||
dict, _, errReply := db.getOrInitDict(key)
|
dict, _, errReply := db.getOrInitDict(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
|
|
||||||
result := dict.Put(field, value)
|
result := dict.Put(field, value)
|
||||||
db.AddAof(makeAofCmd("hset", args))
|
db.AddAof(makeAofCmd("hset", args))
|
||||||
return reply.MakeIntReply(int64(result))
|
return reply.MakeIntReply(int64(result))
|
||||||
}
|
}
|
||||||
|
|
||||||
func HSetNX(db *DB, args [][]byte) redis.Reply {
|
func HSetNX(db *DB, args [][]byte) redis.Reply {
|
||||||
// parse args
|
// parse args
|
||||||
if len(args) != 3 {
|
if len(args) != 3 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hsetnx' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'hsetnx' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
field := string(args[1])
|
field := string(args[1])
|
||||||
value := args[2]
|
value := args[2]
|
||||||
|
|
||||||
db.Lock(key)
|
db.Lock(key)
|
||||||
defer db.UnLock(key)
|
defer db.UnLock(key)
|
||||||
|
|
||||||
dict, _, errReply := db.getOrInitDict(key)
|
dict, _, errReply := db.getOrInitDict(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
|
|
||||||
result := dict.PutIfAbsent(field, value)
|
result := dict.PutIfAbsent(field, value)
|
||||||
if result > 0 {
|
if result > 0 {
|
||||||
db.AddAof(makeAofCmd("hsetnx", args))
|
db.AddAof(makeAofCmd("hsetnx", args))
|
||||||
|
|
||||||
}
|
}
|
||||||
return reply.MakeIntReply(int64(result))
|
return reply.MakeIntReply(int64(result))
|
||||||
}
|
}
|
||||||
|
|
||||||
func HGet(db *DB, args [][]byte) redis.Reply {
|
func HGet(db *DB, args [][]byte) redis.Reply {
|
||||||
// parse args
|
// parse args
|
||||||
if len(args) != 2 {
|
if len(args) != 2 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hget' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'hget' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
field := string(args[1])
|
field := string(args[1])
|
||||||
|
|
||||||
// get entity
|
// get entity
|
||||||
dict, errReply := db.getAsDict(key)
|
dict, errReply := db.getAsDict(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if dict == nil {
|
if dict == nil {
|
||||||
return &reply.NullBulkReply{}
|
return &reply.NullBulkReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
raw, exists := dict.Get(field)
|
raw, exists := dict.Get(field)
|
||||||
if !exists {
|
if !exists {
|
||||||
return &reply.NullBulkReply{}
|
return &reply.NullBulkReply{}
|
||||||
}
|
}
|
||||||
value, _ := raw.([]byte)
|
value, _ := raw.([]byte)
|
||||||
return reply.MakeBulkReply(value)
|
return reply.MakeBulkReply(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func HExists(db *DB, args [][]byte) redis.Reply {
|
func HExists(db *DB, args [][]byte) redis.Reply {
|
||||||
// parse args
|
// parse args
|
||||||
if len(args) != 2 {
|
if len(args) != 2 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hexists' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'hexists' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
field := string(args[1])
|
field := string(args[1])
|
||||||
|
|
||||||
// get entity
|
// get entity
|
||||||
dict, errReply := db.getAsDict(key)
|
dict, errReply := db.getAsDict(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if dict == nil {
|
if dict == nil {
|
||||||
return reply.MakeIntReply(0)
|
return reply.MakeIntReply(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, exists := dict.Get(field)
|
_, exists := dict.Get(field)
|
||||||
if exists {
|
if exists {
|
||||||
return reply.MakeIntReply(1)
|
return reply.MakeIntReply(1)
|
||||||
}
|
}
|
||||||
return reply.MakeIntReply(0)
|
return reply.MakeIntReply(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func HDel(db *DB, args [][]byte) redis.Reply {
|
func HDel(db *DB, args [][]byte) redis.Reply {
|
||||||
// parse args
|
// parse args
|
||||||
if len(args) < 2 {
|
if len(args) < 2 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hdel' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'hdel' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
fields := make([]string, len(args) - 1)
|
fields := make([]string, len(args)-1)
|
||||||
fieldArgs := args[1:]
|
fieldArgs := args[1:]
|
||||||
for i, v := range fieldArgs {
|
for i, v := range fieldArgs {
|
||||||
fields[i] = string(v)
|
fields[i] = string(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Lock(key)
|
db.Lock(key)
|
||||||
defer db.UnLock(key)
|
defer db.UnLock(key)
|
||||||
|
|
||||||
// get entity
|
// get entity
|
||||||
dict, errReply := db.getAsDict(key)
|
dict, errReply := db.getAsDict(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if dict == nil {
|
if dict == nil {
|
||||||
return reply.MakeIntReply(0)
|
return reply.MakeIntReply(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
deleted := 0
|
deleted := 0
|
||||||
for _, field := range fields {
|
for _, field := range fields {
|
||||||
result := dict.Remove(field)
|
result := dict.Remove(field)
|
||||||
deleted += result
|
deleted += result
|
||||||
}
|
}
|
||||||
if dict.Len() == 0 {
|
if dict.Len() == 0 {
|
||||||
db.Remove(key)
|
db.Remove(key)
|
||||||
}
|
}
|
||||||
if deleted > 0 {
|
if deleted > 0 {
|
||||||
db.AddAof(makeAofCmd("hdel", args))
|
db.AddAof(makeAofCmd("hdel", args))
|
||||||
}
|
}
|
||||||
|
|
||||||
return reply.MakeIntReply(int64(deleted))
|
return reply.MakeIntReply(int64(deleted))
|
||||||
}
|
}
|
||||||
|
|
||||||
func HLen(db *DB, args [][]byte) redis.Reply {
|
func HLen(db *DB, args [][]byte) redis.Reply {
|
||||||
// parse args
|
// parse args
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hlen' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'hlen' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
|
|
||||||
dict, errReply := db.getAsDict(key)
|
dict, errReply := db.getAsDict(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if dict == nil {
|
if dict == nil {
|
||||||
return reply.MakeIntReply(0)
|
return reply.MakeIntReply(0)
|
||||||
}
|
}
|
||||||
return reply.MakeIntReply(int64(dict.Len()))
|
return reply.MakeIntReply(int64(dict.Len()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func HMSet(db *DB, args [][]byte) redis.Reply {
|
func HMSet(db *DB, args [][]byte) redis.Reply {
|
||||||
// parse args
|
// parse args
|
||||||
if len(args) < 3 || len(args) % 2 != 1 {
|
if len(args) < 3 || len(args)%2 != 1 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hmset' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'hmset' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
size := (len(args) - 1) / 2
|
size := (len(args) - 1) / 2
|
||||||
fields := make([]string, size)
|
fields := make([]string, size)
|
||||||
values := make([][]byte, size)
|
values := make([][]byte, size)
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
fields[i] = string(args[2 * i + 1])
|
fields[i] = string(args[2*i+1])
|
||||||
values[i] = args[2 * i + 2]
|
values[i] = args[2*i+2]
|
||||||
}
|
}
|
||||||
|
|
||||||
// lock key
|
// lock key
|
||||||
db.Locker.Lock(key)
|
db.Locker.Lock(key)
|
||||||
defer db.Locker.UnLock(key)
|
defer db.Locker.UnLock(key)
|
||||||
|
|
||||||
// get or init entity
|
// get or init entity
|
||||||
dict, _, errReply := db.getOrInitDict(key)
|
dict, _, errReply := db.getOrInitDict(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
|
|
||||||
// put data
|
// put data
|
||||||
for i, field := range fields {
|
for i, field := range fields {
|
||||||
value := values[i]
|
value := values[i]
|
||||||
dict.Put(field, value)
|
dict.Put(field, value)
|
||||||
}
|
}
|
||||||
db.AddAof(makeAofCmd("hmset", args))
|
db.AddAof(makeAofCmd("hmset", args))
|
||||||
return &reply.OkReply{}
|
return &reply.OkReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func HMGet(db *DB, args [][]byte) redis.Reply {
|
func HMGet(db *DB, args [][]byte) redis.Reply {
|
||||||
if len(args) < 2 {
|
if len(args) < 2 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hmget' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'hmget' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
size := len(args) - 1
|
size := len(args) - 1
|
||||||
fields := make([]string, size)
|
fields := make([]string, size)
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
fields[i] = string(args[i + 1])
|
fields[i] = string(args[i+1])
|
||||||
}
|
}
|
||||||
|
|
||||||
db.RLock(key)
|
db.RLock(key)
|
||||||
defer db.RUnLock(key)
|
defer db.RUnLock(key)
|
||||||
|
|
||||||
// get entity
|
// get entity
|
||||||
result := make([][]byte, size)
|
result := make([][]byte, size)
|
||||||
dict, errReply := db.getAsDict(key)
|
dict, errReply := db.getAsDict(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if dict == nil {
|
if dict == nil {
|
||||||
return reply.MakeMultiBulkReply(result)
|
return reply.MakeMultiBulkReply(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, field := range fields {
|
for i, field := range fields {
|
||||||
value, ok := dict.Get(field)
|
value, ok := dict.Get(field)
|
||||||
if !ok {
|
if !ok {
|
||||||
result[i] = nil
|
result[i] = nil
|
||||||
} else {
|
} else {
|
||||||
bytes, _ := value.([]byte)
|
bytes, _ := value.([]byte)
|
||||||
result[i] = bytes
|
result[i] = bytes
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return reply.MakeMultiBulkReply(result)
|
return reply.MakeMultiBulkReply(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func HKeys(db *DB, args [][]byte) redis.Reply {
|
func HKeys(db *DB, args [][]byte) redis.Reply {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hkeys' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'hkeys' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
|
|
||||||
db.RLock(key)
|
db.RLock(key)
|
||||||
defer db.RUnLock(key)
|
defer db.RUnLock(key)
|
||||||
|
|
||||||
dict, errReply := db.getAsDict(key)
|
dict, errReply := db.getAsDict(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if dict == nil {
|
if dict == nil {
|
||||||
return &reply.EmptyMultiBulkReply{}
|
return &reply.EmptyMultiBulkReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
fields := make([][]byte, dict.Len())
|
fields := make([][]byte, dict.Len())
|
||||||
i := 0
|
i := 0
|
||||||
dict.ForEach(func(key string, val interface{})bool {
|
dict.ForEach(func(key string, val interface{}) bool {
|
||||||
fields[i] = []byte(key)
|
fields[i] = []byte(key)
|
||||||
i++
|
i++
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return reply.MakeMultiBulkReply(fields[:i])
|
return reply.MakeMultiBulkReply(fields[:i])
|
||||||
}
|
}
|
||||||
|
|
||||||
func HVals(db *DB, args [][]byte) redis.Reply {
|
func HVals(db *DB, args [][]byte) redis.Reply {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hvals' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'hvals' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
|
|
||||||
db.RLock(key)
|
db.RLock(key)
|
||||||
defer db.RUnLock(key)
|
defer db.RUnLock(key)
|
||||||
|
|
||||||
// get entity
|
// get entity
|
||||||
dict, errReply := db.getAsDict(key)
|
dict, errReply := db.getAsDict(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if dict == nil {
|
if dict == nil {
|
||||||
return &reply.EmptyMultiBulkReply{}
|
return &reply.EmptyMultiBulkReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
values := make([][]byte, dict.Len())
|
values := make([][]byte, dict.Len())
|
||||||
i := 0
|
i := 0
|
||||||
dict.ForEach(func(key string, val interface{})bool {
|
dict.ForEach(func(key string, val interface{}) bool {
|
||||||
values[i], _ = val.([]byte)
|
values[i], _ = val.([]byte)
|
||||||
i++
|
i++
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return reply.MakeMultiBulkReply(values[:i])
|
return reply.MakeMultiBulkReply(values[:i])
|
||||||
}
|
}
|
||||||
|
|
||||||
func HGetAll(db *DB, args [][]byte) redis.Reply {
|
func HGetAll(db *DB, args [][]byte) redis.Reply {
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hgetAll' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'hgetAll' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
|
|
||||||
db.RLock(key)
|
db.RLock(key)
|
||||||
defer db.RUnLock(key)
|
defer db.RUnLock(key)
|
||||||
|
|
||||||
// get entity
|
// get entity
|
||||||
dict, errReply := db.getAsDict(key)
|
dict, errReply := db.getAsDict(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if dict == nil {
|
if dict == nil {
|
||||||
return &reply.EmptyMultiBulkReply{}
|
return &reply.EmptyMultiBulkReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
size := dict.Len()
|
size := dict.Len()
|
||||||
result := make([][]byte, size * 2)
|
result := make([][]byte, size*2)
|
||||||
i := 0
|
i := 0
|
||||||
dict.ForEach(func(key string, val interface{})bool {
|
dict.ForEach(func(key string, val interface{}) bool {
|
||||||
result[i] = []byte(key)
|
result[i] = []byte(key)
|
||||||
i++
|
i++
|
||||||
result[i], _ = val.([]byte)
|
result[i], _ = val.([]byte)
|
||||||
i++
|
i++
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return reply.MakeMultiBulkReply(result[:i])
|
return reply.MakeMultiBulkReply(result[:i])
|
||||||
}
|
}
|
||||||
|
|
||||||
func HIncrBy(db *DB, args [][]byte) redis.Reply {
|
func HIncrBy(db *DB, args [][]byte) redis.Reply {
|
||||||
if len(args) != 3 {
|
if len(args) != 3 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hincrby' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'hincrby' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
field := string(args[1])
|
field := string(args[1])
|
||||||
rawDelta := string(args[2])
|
rawDelta := string(args[2])
|
||||||
delta, err := strconv.ParseInt(rawDelta, 10, 64)
|
delta, err := strconv.ParseInt(rawDelta, 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Locker.Lock(key)
|
db.Locker.Lock(key)
|
||||||
defer db.Locker.UnLock(key)
|
defer db.Locker.UnLock(key)
|
||||||
|
|
||||||
dict, _, errReply := db.getOrInitDict(key)
|
dict, _, errReply := db.getOrInitDict(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
|
|
||||||
value, exists := dict.Get(field)
|
value, exists := dict.Get(field)
|
||||||
if !exists {
|
if !exists {
|
||||||
dict.Put(field, args[2])
|
dict.Put(field, args[2])
|
||||||
db.AddAof(makeAofCmd("hincrby", args))
|
db.AddAof(makeAofCmd("hincrby", args))
|
||||||
return reply.MakeBulkReply(args[2])
|
return reply.MakeBulkReply(args[2])
|
||||||
} else {
|
} else {
|
||||||
val, err := strconv.ParseInt(string(value.([]byte)), 10, 64)
|
val, err := strconv.ParseInt(string(value.([]byte)), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reply.MakeErrReply("ERR hash value is not an integer")
|
return reply.MakeErrReply("ERR hash value is not an integer")
|
||||||
}
|
}
|
||||||
val += delta
|
val += delta
|
||||||
bytes := []byte(strconv.FormatInt(val, 10))
|
bytes := []byte(strconv.FormatInt(val, 10))
|
||||||
dict.Put(field, bytes)
|
dict.Put(field, bytes)
|
||||||
db.AddAof(makeAofCmd("hincrby", args))
|
db.AddAof(makeAofCmd("hincrby", args))
|
||||||
return reply.MakeBulkReply(bytes)
|
return reply.MakeBulkReply(bytes)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func HIncrByFloat(db *DB, args [][]byte) redis.Reply {
|
func HIncrByFloat(db *DB, args [][]byte) redis.Reply {
|
||||||
if len(args) != 3 {
|
if len(args) != 3 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'hincrbyfloat' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'hincrbyfloat' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
field := string(args[1])
|
field := string(args[1])
|
||||||
rawDelta := string(args[2])
|
rawDelta := string(args[2])
|
||||||
delta, err := decimal.NewFromString(rawDelta)
|
delta, err := decimal.NewFromString(rawDelta)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reply.MakeErrReply("ERR value is not a valid float")
|
return reply.MakeErrReply("ERR value is not a valid float")
|
||||||
}
|
}
|
||||||
|
|
||||||
db.Lock(key)
|
db.Lock(key)
|
||||||
defer db.UnLock(key)
|
defer db.UnLock(key)
|
||||||
|
|
||||||
// get or init entity
|
// get or init entity
|
||||||
dict, _, errReply := db.getOrInitDict(key)
|
dict, _, errReply := db.getOrInitDict(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
|
|
||||||
value, exists := dict.Get(field)
|
value, exists := dict.Get(field)
|
||||||
if !exists {
|
if !exists {
|
||||||
dict.Put(field, args[2])
|
dict.Put(field, args[2])
|
||||||
return reply.MakeBulkReply(args[2])
|
return reply.MakeBulkReply(args[2])
|
||||||
} else {
|
} else {
|
||||||
val, err := decimal.NewFromString(string(value.([]byte)))
|
val, err := decimal.NewFromString(string(value.([]byte)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reply.MakeErrReply("ERR hash value is not a float")
|
return reply.MakeErrReply("ERR hash value is not a float")
|
||||||
}
|
}
|
||||||
result := val.Add(delta)
|
result := val.Add(delta)
|
||||||
resultBytes:= []byte(result.String())
|
resultBytes := []byte(result.String())
|
||||||
dict.Put(field, resultBytes)
|
dict.Put(field, resultBytes)
|
||||||
db.AddAof(makeAofCmd("hincrbyfloat", args))
|
db.AddAof(makeAofCmd("hincrbyfloat", args))
|
||||||
return reply.MakeBulkReply(resultBytes)
|
return reply.MakeBulkReply(resultBytes)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
704
src/db/list.go
704
src/db/list.go
@@ -1,439 +1,439 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
List "github.com/HDT3213/godis/src/datastruct/list"
|
List "github.com/HDT3213/godis/src/datastruct/list"
|
||||||
"github.com/HDT3213/godis/src/interface/redis"
|
"github.com/HDT3213/godis/src/interface/redis"
|
||||||
"github.com/HDT3213/godis/src/redis/reply"
|
"github.com/HDT3213/godis/src/redis/reply"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (db *DB) getAsList(key string)(*List.LinkedList, reply.ErrorReply) {
|
func (db *DB) getAsList(key string) (*List.LinkedList, reply.ErrorReply) {
|
||||||
entity, ok := db.Get(key)
|
entity, ok := db.Get(key)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
bytes, ok := entity.Data.(*List.LinkedList)
|
bytes, ok := entity.Data.(*List.LinkedList)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, &reply.WrongTypeErrReply{}
|
return nil, &reply.WrongTypeErrReply{}
|
||||||
}
|
}
|
||||||
return bytes, nil
|
return bytes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) getOrInitList(key string)(list *List.LinkedList, inited bool, errReply reply.ErrorReply) {
|
func (db *DB) getOrInitList(key string) (list *List.LinkedList, inited bool, errReply reply.ErrorReply) {
|
||||||
list, errReply = db.getAsList(key)
|
list, errReply = db.getAsList(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return nil, false, errReply
|
return nil, false, errReply
|
||||||
}
|
}
|
||||||
inited = false
|
inited = false
|
||||||
if list == nil {
|
if list == nil {
|
||||||
list = &List.LinkedList{}
|
list = &List.LinkedList{}
|
||||||
db.Put(key, &DataEntity{
|
db.Put(key, &DataEntity{
|
||||||
Data: list,
|
Data: list,
|
||||||
})
|
})
|
||||||
inited = true
|
inited = true
|
||||||
}
|
}
|
||||||
return list, inited, nil
|
return list, inited, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func LIndex(db *DB, args [][]byte) redis.Reply {
|
func LIndex(db *DB, args [][]byte) redis.Reply {
|
||||||
// parse args
|
// parse args
|
||||||
if len(args) != 2 {
|
if len(args) != 2 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
index64, err := strconv.ParseInt(string(args[1]), 10, 64)
|
index64, err := strconv.ParseInt(string(args[1]), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||||
}
|
}
|
||||||
index := int(index64)
|
index := int(index64)
|
||||||
|
|
||||||
// get entity
|
// get entity
|
||||||
list, errReply := db.getAsList(key)
|
list, errReply := db.getAsList(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if list == nil {
|
if list == nil {
|
||||||
return &reply.NullBulkReply{}
|
return &reply.NullBulkReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
size := list.Len() // assert: size > 0
|
size := list.Len() // assert: size > 0
|
||||||
if index < -1 * size {
|
if index < -1*size {
|
||||||
return &reply.NullBulkReply{}
|
return &reply.NullBulkReply{}
|
||||||
} else if index < 0 {
|
} else if index < 0 {
|
||||||
index = size + index
|
index = size + index
|
||||||
} else if index >= size {
|
} else if index >= size {
|
||||||
return &reply.NullBulkReply{}
|
return &reply.NullBulkReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
val, _ := list.Get(index).([]byte)
|
val, _ := list.Get(index).([]byte)
|
||||||
return reply.MakeBulkReply(val)
|
return reply.MakeBulkReply(val)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LLen(db *DB, args [][]byte) redis.Reply {
|
func LLen(db *DB, args [][]byte) redis.Reply {
|
||||||
// parse args
|
// parse args
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'llen' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'llen' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
|
|
||||||
list, errReply := db.getAsList(key)
|
list, errReply := db.getAsList(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if list == nil {
|
if list == nil {
|
||||||
return reply.MakeIntReply(0)
|
return reply.MakeIntReply(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
size := int64(list.Len())
|
size := int64(list.Len())
|
||||||
return reply.MakeIntReply(size)
|
return reply.MakeIntReply(size)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LPop(db *DB, args [][]byte) redis.Reply {
|
func LPop(db *DB, args [][]byte) redis.Reply {
|
||||||
// parse args
|
// parse args
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
|
|
||||||
// lock
|
// lock
|
||||||
db.Lock(key)
|
db.Lock(key)
|
||||||
defer db.UnLock(key)
|
defer db.UnLock(key)
|
||||||
|
|
||||||
// get data
|
// get data
|
||||||
list, errReply := db.getAsList(key)
|
list, errReply := db.getAsList(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if list == nil {
|
if list == nil {
|
||||||
return &reply.NullBulkReply{}
|
return &reply.NullBulkReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
val, _ := list.Remove(0).([]byte)
|
val, _ := list.Remove(0).([]byte)
|
||||||
if list.Len() == 0 {
|
if list.Len() == 0 {
|
||||||
db.Remove(key)
|
db.Remove(key)
|
||||||
}
|
}
|
||||||
db.AddAof(makeAofCmd("lpop", args))
|
db.AddAof(makeAofCmd("lpop", args))
|
||||||
return reply.MakeBulkReply(val)
|
return reply.MakeBulkReply(val)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LPush(db *DB, args [][]byte) redis.Reply {
|
func LPush(db *DB, args [][]byte) redis.Reply {
|
||||||
if len(args) < 2 {
|
if len(args) < 2 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lpush' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'lpush' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
values := args[1:]
|
values := args[1:]
|
||||||
|
|
||||||
// lock
|
// lock
|
||||||
db.Locker.Lock(key)
|
db.Locker.Lock(key)
|
||||||
defer db.Locker.UnLock(key)
|
defer db.Locker.UnLock(key)
|
||||||
|
|
||||||
// get or init entity
|
// get or init entity
|
||||||
list, _, errReply := db.getOrInitList(key)
|
list, _, errReply := db.getOrInitList(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
|
|
||||||
// insert
|
// insert
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
list.Insert(0, value)
|
list.Insert(0, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AddAof(makeAofCmd("lpush", args))
|
db.AddAof(makeAofCmd("lpush", args))
|
||||||
return reply.MakeIntReply(int64(list.Len()))
|
return reply.MakeIntReply(int64(list.Len()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func LPushX(db *DB, args [][]byte) redis.Reply {
|
func LPushX(db *DB, args [][]byte) redis.Reply {
|
||||||
if len(args) < 2 {
|
if len(args) < 2 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lpushx' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'lpushx' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
values := args[1:]
|
values := args[1:]
|
||||||
|
|
||||||
// lock
|
// lock
|
||||||
db.Locker.Lock(key)
|
db.Locker.Lock(key)
|
||||||
defer db.Locker.UnLock(key)
|
defer db.Locker.UnLock(key)
|
||||||
|
|
||||||
// get or init entity
|
// get or init entity
|
||||||
list, errReply := db.getAsList(key)
|
list, errReply := db.getAsList(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if list == nil {
|
if list == nil {
|
||||||
return reply.MakeIntReply(0)
|
return reply.MakeIntReply(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// insert
|
// insert
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
list.Insert(0, value)
|
list.Insert(0, value)
|
||||||
}
|
}
|
||||||
db.AddAof(makeAofCmd("lpushx", args))
|
db.AddAof(makeAofCmd("lpushx", args))
|
||||||
return reply.MakeIntReply(int64(list.Len()))
|
return reply.MakeIntReply(int64(list.Len()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func LRange(db *DB, args [][]byte) redis.Reply {
|
func LRange(db *DB, args [][]byte) redis.Reply {
|
||||||
// parse args
|
// parse args
|
||||||
if len(args) != 3 {
|
if len(args) != 3 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lrange' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'lrange' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
start64, err := strconv.ParseInt(string(args[1]), 10, 64)
|
start64, err := strconv.ParseInt(string(args[1]), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||||
}
|
}
|
||||||
start := int(start64)
|
start := int(start64)
|
||||||
stop64, err := strconv.ParseInt(string(args[2]), 10, 64)
|
stop64, err := strconv.ParseInt(string(args[2]), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||||
}
|
}
|
||||||
stop := int(stop64)
|
stop := int(stop64)
|
||||||
|
|
||||||
// lock key
|
// lock key
|
||||||
db.RLock(key)
|
db.RLock(key)
|
||||||
defer db.RUnLock(key)
|
defer db.RUnLock(key)
|
||||||
|
|
||||||
// get data
|
// get data
|
||||||
list, errReply := db.getAsList(key)
|
list, errReply := db.getAsList(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if list == nil {
|
if list == nil {
|
||||||
return &reply.EmptyMultiBulkReply{}
|
return &reply.EmptyMultiBulkReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// compute index
|
// compute index
|
||||||
size := list.Len() // assert: size > 0
|
size := list.Len() // assert: size > 0
|
||||||
if start < -1 * size {
|
if start < -1*size {
|
||||||
start = 0
|
start = 0
|
||||||
} else if start < 0 {
|
} else if start < 0 {
|
||||||
start = size + start
|
start = size + start
|
||||||
} else if start >= size {
|
} else if start >= size {
|
||||||
return &reply.EmptyMultiBulkReply{}
|
return &reply.EmptyMultiBulkReply{}
|
||||||
}
|
}
|
||||||
if stop < -1 * size {
|
if stop < -1*size {
|
||||||
stop = 0
|
stop = 0
|
||||||
} else if stop < 0 {
|
} else if stop < 0 {
|
||||||
stop = size + stop + 1
|
stop = size + stop + 1
|
||||||
} else if stop < size {
|
} else if stop < size {
|
||||||
stop = stop + 1
|
stop = stop + 1
|
||||||
} else {
|
} else {
|
||||||
stop = size
|
stop = size
|
||||||
}
|
}
|
||||||
if stop < start {
|
if stop < start {
|
||||||
stop = start
|
stop = start
|
||||||
}
|
}
|
||||||
|
|
||||||
// assert: start in [0, size - 1], stop in [start, size]
|
// assert: start in [0, size - 1], stop in [start, size]
|
||||||
slice := list.Range(start, stop)
|
slice := list.Range(start, stop)
|
||||||
result := make([][]byte, len(slice))
|
result := make([][]byte, len(slice))
|
||||||
for i, raw := range slice {
|
for i, raw := range slice {
|
||||||
bytes, _ := raw.([]byte)
|
bytes, _ := raw.([]byte)
|
||||||
result[i] = bytes
|
result[i] = bytes
|
||||||
}
|
}
|
||||||
return reply.MakeMultiBulkReply(result)
|
return reply.MakeMultiBulkReply(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func LRem(db *DB, args [][]byte) redis.Reply {
|
func LRem(db *DB, args [][]byte) redis.Reply {
|
||||||
// parse args
|
// parse args
|
||||||
if len(args) != 3 {
|
if len(args) != 3 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lrem' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'lrem' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
count64, err := strconv.ParseInt(string(args[1]), 10, 64)
|
count64, err := strconv.ParseInt(string(args[1]), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||||
}
|
}
|
||||||
count := int(count64)
|
count := int(count64)
|
||||||
value := args[2]
|
value := args[2]
|
||||||
|
|
||||||
// lock
|
// lock
|
||||||
db.Lock(key)
|
db.Lock(key)
|
||||||
defer db.UnLock(key)
|
defer db.UnLock(key)
|
||||||
|
|
||||||
// get data entity
|
// get data entity
|
||||||
list, errReply := db.getAsList(key)
|
list, errReply := db.getAsList(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if list == nil {
|
if list == nil {
|
||||||
return reply.MakeIntReply(0)
|
return reply.MakeIntReply(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
var removed int
|
var removed int
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
removed = list.RemoveAllByVal(value)
|
removed = list.RemoveAllByVal(value)
|
||||||
} else if count > 0 {
|
} else if count > 0 {
|
||||||
removed = list.RemoveByVal(value, count)
|
removed = list.RemoveByVal(value, count)
|
||||||
} else {
|
} else {
|
||||||
removed = list.ReverseRemoveByVal(value, -count)
|
removed = list.ReverseRemoveByVal(value, -count)
|
||||||
}
|
}
|
||||||
|
|
||||||
if list.Len() == 0 {
|
if list.Len() == 0 {
|
||||||
db.Remove(key)
|
db.Remove(key)
|
||||||
}
|
}
|
||||||
if removed > 0 {
|
if removed > 0 {
|
||||||
db.AddAof(makeAofCmd("lrem", args))
|
db.AddAof(makeAofCmd("lrem", args))
|
||||||
}
|
}
|
||||||
|
|
||||||
return reply.MakeIntReply(int64(removed))
|
return reply.MakeIntReply(int64(removed))
|
||||||
}
|
}
|
||||||
|
|
||||||
func LSet(db *DB, args [][]byte) redis.Reply {
|
func LSet(db *DB, args [][]byte) redis.Reply {
|
||||||
// parse args
|
// parse args
|
||||||
if len(args) != 3 {
|
if len(args) != 3 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'lset' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'lset' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
index64, err := strconv.ParseInt(string(args[1]), 10, 64)
|
index64, err := strconv.ParseInt(string(args[1]), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
return reply.MakeErrReply("ERR value is not an integer or out of range")
|
||||||
}
|
}
|
||||||
index := int(index64)
|
index := int(index64)
|
||||||
value := args[2]
|
value := args[2]
|
||||||
|
|
||||||
// lock
|
// lock
|
||||||
db.Locker.Lock(key)
|
db.Locker.Lock(key)
|
||||||
defer db.Locker.UnLock(key)
|
defer db.Locker.UnLock(key)
|
||||||
|
|
||||||
// get data
|
// get data
|
||||||
list, errReply := db.getAsList(key)
|
list, errReply := db.getAsList(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if list == nil {
|
if list == nil {
|
||||||
return reply.MakeErrReply("ERR no such key")
|
return reply.MakeErrReply("ERR no such key")
|
||||||
}
|
}
|
||||||
|
|
||||||
size := list.Len() // assert: size > 0
|
size := list.Len() // assert: size > 0
|
||||||
if index < -1 * size {
|
if index < -1*size {
|
||||||
return reply.MakeErrReply("ERR index out of range")
|
return reply.MakeErrReply("ERR index out of range")
|
||||||
} else if index < 0 {
|
} else if index < 0 {
|
||||||
index = size + index
|
index = size + index
|
||||||
} else if index >= size {
|
} else if index >= size {
|
||||||
return reply.MakeErrReply("ERR index out of range")
|
return reply.MakeErrReply("ERR index out of range")
|
||||||
}
|
}
|
||||||
|
|
||||||
list.Set(index, value)
|
list.Set(index, value)
|
||||||
db.AddAof(makeAofCmd("lset", args))
|
db.AddAof(makeAofCmd("lset", args))
|
||||||
return &reply.OkReply{}
|
return &reply.OkReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func RPop(db *DB, args [][]byte) redis.Reply {
|
func RPop(db *DB, args [][]byte) redis.Reply {
|
||||||
// parse args
|
// parse args
|
||||||
if len(args) != 1 {
|
if len(args) != 1 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rpop' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'rpop' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
|
|
||||||
// lock
|
// lock
|
||||||
db.Lock(key)
|
db.Lock(key)
|
||||||
defer db.UnLock(key)
|
defer db.UnLock(key)
|
||||||
|
|
||||||
// get data
|
// get data
|
||||||
list, errReply := db.getAsList(key)
|
list, errReply := db.getAsList(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if list == nil {
|
if list == nil {
|
||||||
return &reply.NullBulkReply{}
|
return &reply.NullBulkReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
val, _ := list.RemoveLast().([]byte)
|
val, _ := list.RemoveLast().([]byte)
|
||||||
if list.Len() == 0 {
|
if list.Len() == 0 {
|
||||||
db.Remove(key)
|
db.Remove(key)
|
||||||
}
|
}
|
||||||
db.AddAof(makeAofCmd("rpop", args))
|
db.AddAof(makeAofCmd("rpop", args))
|
||||||
return reply.MakeBulkReply(val)
|
return reply.MakeBulkReply(val)
|
||||||
}
|
}
|
||||||
|
|
||||||
func RPopLPush(db *DB, args [][]byte) redis.Reply {
|
func RPopLPush(db *DB, args [][]byte) redis.Reply {
|
||||||
if len(args) != 2 {
|
if len(args) != 2 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rpoplpush' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'rpoplpush' command")
|
||||||
}
|
}
|
||||||
sourceKey := string(args[0])
|
sourceKey := string(args[0])
|
||||||
destKey := string(args[1])
|
destKey := string(args[1])
|
||||||
|
|
||||||
// lock
|
// lock
|
||||||
db.Locks(sourceKey, destKey)
|
db.Locks(sourceKey, destKey)
|
||||||
defer db.UnLocks(sourceKey, destKey)
|
defer db.UnLocks(sourceKey, destKey)
|
||||||
|
|
||||||
// get source entity
|
// get source entity
|
||||||
sourceList, errReply := db.getAsList(sourceKey)
|
sourceList, errReply := db.getAsList(sourceKey)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if sourceList == nil {
|
if sourceList == nil {
|
||||||
return &reply.NullBulkReply{}
|
return &reply.NullBulkReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// get dest entity
|
// get dest entity
|
||||||
destList, _, errReply := db.getOrInitList(destKey)
|
destList, _, errReply := db.getOrInitList(destKey)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
|
|
||||||
// pop and push
|
// pop and push
|
||||||
val, _ := sourceList.RemoveLast().([]byte)
|
val, _ := sourceList.RemoveLast().([]byte)
|
||||||
destList.Insert(0, val)
|
destList.Insert(0, val)
|
||||||
|
|
||||||
if sourceList.Len() == 0 {
|
if sourceList.Len() == 0 {
|
||||||
db.Remove(sourceKey)
|
db.Remove(sourceKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
db.AddAof(makeAofCmd("rpoplpush", args))
|
db.AddAof(makeAofCmd("rpoplpush", args))
|
||||||
return reply.MakeBulkReply(val)
|
return reply.MakeBulkReply(val)
|
||||||
}
|
}
|
||||||
|
|
||||||
func RPush(db *DB, args [][]byte) redis.Reply {
|
func RPush(db *DB, args [][]byte) redis.Reply {
|
||||||
// parse args
|
// parse args
|
||||||
if len(args) < 2 {
|
if len(args) < 2 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
values := args[1:]
|
values := args[1:]
|
||||||
|
|
||||||
// lock
|
// lock
|
||||||
db.Lock(key)
|
db.Lock(key)
|
||||||
defer db.UnLock(key)
|
defer db.UnLock(key)
|
||||||
|
|
||||||
// get or init entity
|
// get or init entity
|
||||||
list, _, errReply := db.getOrInitList(key)
|
list, _, errReply := db.getOrInitList(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
|
|
||||||
// put list
|
// put list
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
list.Add(value)
|
list.Add(value)
|
||||||
}
|
}
|
||||||
db.AddAof(makeAofCmd("rpush", args))
|
db.AddAof(makeAofCmd("rpush", args))
|
||||||
return reply.MakeIntReply(int64(list.Len()))
|
return reply.MakeIntReply(int64(list.Len()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func RPushX(db *DB, args [][]byte) redis.Reply {
|
func RPushX(db *DB, args [][]byte) redis.Reply {
|
||||||
if len(args) < 2 {
|
if len(args) < 2 {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command")
|
||||||
}
|
}
|
||||||
key := string(args[0])
|
key := string(args[0])
|
||||||
values := args[1:]
|
values := args[1:]
|
||||||
|
|
||||||
// lock
|
// lock
|
||||||
db.Lock(key)
|
db.Lock(key)
|
||||||
defer db.UnLock(key)
|
defer db.UnLock(key)
|
||||||
|
|
||||||
// get or init entity
|
// get or init entity
|
||||||
list, errReply := db.getAsList(key)
|
list, errReply := db.getAsList(key)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
return errReply
|
return errReply
|
||||||
}
|
}
|
||||||
if list == nil {
|
if list == nil {
|
||||||
return reply.MakeIntReply(0)
|
return reply.MakeIntReply(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// put list
|
// put list
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
list.Add(value)
|
list.Add(value)
|
||||||
}
|
}
|
||||||
db.AddAof(makeAofCmd("rpushx", args))
|
db.AddAof(makeAofCmd("rpushx", args))
|
||||||
|
|
||||||
return reply.MakeIntReply(int64(list.Len()))
|
return reply.MakeIntReply(int64(list.Len()))
|
||||||
}
|
}
|
||||||
|
180
src/db/router.go
180
src/db/router.go
@@ -1,102 +1,102 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
func MakeRouter()map[string]CmdFunc {
|
func MakeRouter() map[string]CmdFunc {
|
||||||
routerMap := make(map[string]CmdFunc)
|
routerMap := make(map[string]CmdFunc)
|
||||||
routerMap["ping"] = Ping
|
routerMap["ping"] = Ping
|
||||||
|
|
||||||
routerMap["del"] = Del
|
routerMap["del"] = Del
|
||||||
routerMap["expire"] = Expire
|
routerMap["expire"] = Expire
|
||||||
routerMap["expireat"] = ExpireAt
|
routerMap["expireat"] = ExpireAt
|
||||||
routerMap["pexpire"] = PExpire
|
routerMap["pexpire"] = PExpire
|
||||||
routerMap["pexpireat"] = PExpireAt
|
routerMap["pexpireat"] = PExpireAt
|
||||||
routerMap["ttl"] = TTL
|
routerMap["ttl"] = TTL
|
||||||
routerMap["pttl"] = PTTL
|
routerMap["pttl"] = PTTL
|
||||||
routerMap["persist"] = Persist
|
routerMap["persist"] = Persist
|
||||||
routerMap["exists"] = Exists
|
routerMap["exists"] = Exists
|
||||||
routerMap["type"] = Type
|
routerMap["type"] = Type
|
||||||
routerMap["rename"] = Rename
|
routerMap["rename"] = Rename
|
||||||
routerMap["renamenx"] = RenameNx
|
routerMap["renamenx"] = RenameNx
|
||||||
|
|
||||||
routerMap["set"] = Set
|
routerMap["set"] = Set
|
||||||
routerMap["setnx"] = SetNX
|
routerMap["setnx"] = SetNX
|
||||||
routerMap["setex"] = SetEX
|
routerMap["setex"] = SetEX
|
||||||
routerMap["psetex"] = PSetEX
|
routerMap["psetex"] = PSetEX
|
||||||
routerMap["mset"] = MSet
|
routerMap["mset"] = MSet
|
||||||
routerMap["mget"] = MGet
|
routerMap["mget"] = MGet
|
||||||
routerMap["msetnx"] = MSetNX
|
routerMap["msetnx"] = MSetNX
|
||||||
routerMap["get"] = Get
|
routerMap["get"] = Get
|
||||||
routerMap["getset"] = GetSet
|
routerMap["getset"] = GetSet
|
||||||
routerMap["incr"] = Incr
|
routerMap["incr"] = Incr
|
||||||
routerMap["incrby"] = IncrBy
|
routerMap["incrby"] = IncrBy
|
||||||
routerMap["incrbyfloat"] = IncrByFloat
|
routerMap["incrbyfloat"] = IncrByFloat
|
||||||
routerMap["decr"] = Decr
|
routerMap["decr"] = Decr
|
||||||
routerMap["decrby"] = DecrBy
|
routerMap["decrby"] = DecrBy
|
||||||
|
|
||||||
routerMap["lpush"] = LPush
|
routerMap["lpush"] = LPush
|
||||||
routerMap["lpushx"] = LPushX
|
routerMap["lpushx"] = LPushX
|
||||||
routerMap["rpush"] = RPush
|
routerMap["rpush"] = RPush
|
||||||
routerMap["rpushx"] = RPushX
|
routerMap["rpushx"] = RPushX
|
||||||
routerMap["lpop"] = LPop
|
routerMap["lpop"] = LPop
|
||||||
routerMap["rpop"] = RPop
|
routerMap["rpop"] = RPop
|
||||||
routerMap["rpoplpush"] = RPopLPush
|
routerMap["rpoplpush"] = RPopLPush
|
||||||
routerMap["lrem"] = LRem
|
routerMap["lrem"] = LRem
|
||||||
routerMap["llen"] = LLen
|
routerMap["llen"] = LLen
|
||||||
routerMap["lindex"] = LIndex
|
routerMap["lindex"] = LIndex
|
||||||
routerMap["lset"] = LSet
|
routerMap["lset"] = LSet
|
||||||
routerMap["lrange"] = LRange
|
routerMap["lrange"] = LRange
|
||||||
|
|
||||||
routerMap["hset"] = HSet
|
routerMap["hset"] = HSet
|
||||||
routerMap["hsetnx"] = HSetNX
|
routerMap["hsetnx"] = HSetNX
|
||||||
routerMap["hget"] = HGet
|
routerMap["hget"] = HGet
|
||||||
routerMap["hexists"] = HExists
|
routerMap["hexists"] = HExists
|
||||||
routerMap["hdel"] = HDel
|
routerMap["hdel"] = HDel
|
||||||
routerMap["hlen"] = HLen
|
routerMap["hlen"] = HLen
|
||||||
routerMap["hmget"] = HMGet
|
routerMap["hmget"] = HMGet
|
||||||
routerMap["hmset"] = HMSet
|
routerMap["hmset"] = HMSet
|
||||||
routerMap["hkeys"] = HKeys
|
routerMap["hkeys"] = HKeys
|
||||||
routerMap["hvals"] = HVals
|
routerMap["hvals"] = HVals
|
||||||
routerMap["hgetall"] = HGetAll
|
routerMap["hgetall"] = HGetAll
|
||||||
routerMap["hincrby"] = HIncrBy
|
routerMap["hincrby"] = HIncrBy
|
||||||
routerMap["hincrbyfloat"] = HIncrByFloat
|
routerMap["hincrbyfloat"] = HIncrByFloat
|
||||||
|
|
||||||
routerMap["sadd"] = SAdd
|
routerMap["sadd"] = SAdd
|
||||||
routerMap["sismember"] = SIsMember
|
routerMap["sismember"] = SIsMember
|
||||||
routerMap["srem"] = SRem
|
routerMap["srem"] = SRem
|
||||||
routerMap["scard"] = SCard
|
routerMap["scard"] = SCard
|
||||||
routerMap["smembers"] = SMembers
|
routerMap["smembers"] = SMembers
|
||||||
routerMap["sinter"] = SInter
|
routerMap["sinter"] = SInter
|
||||||
routerMap["sinterstore"] = SInterStore
|
routerMap["sinterstore"] = SInterStore
|
||||||
routerMap["sunion"] = SUnion
|
routerMap["sunion"] = SUnion
|
||||||
routerMap["sunionstore"] = SUnionStore
|
routerMap["sunionstore"] = SUnionStore
|
||||||
routerMap["sdiff"] = SDiff
|
routerMap["sdiff"] = SDiff
|
||||||
routerMap["sdiffstore"] = SDiffStore
|
routerMap["sdiffstore"] = SDiffStore
|
||||||
routerMap["srandmember"] = SRandMember
|
routerMap["srandmember"] = SRandMember
|
||||||
|
|
||||||
routerMap["zadd"] = ZAdd
|
routerMap["zadd"] = ZAdd
|
||||||
routerMap["zscore"] = ZScore
|
routerMap["zscore"] = ZScore
|
||||||
routerMap["zincrby"] = ZIncrBy
|
routerMap["zincrby"] = ZIncrBy
|
||||||
routerMap["zrank"] = ZRank
|
routerMap["zrank"] = ZRank
|
||||||
routerMap["zcount"] = ZCount
|
routerMap["zcount"] = ZCount
|
||||||
routerMap["zrevrank"] = ZRevRank
|
routerMap["zrevrank"] = ZRevRank
|
||||||
routerMap["zcard"] = ZCard
|
routerMap["zcard"] = ZCard
|
||||||
routerMap["zrange"] = ZRange
|
routerMap["zrange"] = ZRange
|
||||||
routerMap["zrevrange"] = ZRevRange
|
routerMap["zrevrange"] = ZRevRange
|
||||||
routerMap["zrangebyscore"] = ZRangeByScore
|
routerMap["zrangebyscore"] = ZRangeByScore
|
||||||
routerMap["zrevrangebyscore"] = ZRevRangeByScore
|
routerMap["zrevrangebyscore"] = ZRevRangeByScore
|
||||||
routerMap["zrem"] = ZRem
|
routerMap["zrem"] = ZRem
|
||||||
routerMap["zremrangebyscore"] = ZRemRangeByScore
|
routerMap["zremrangebyscore"] = ZRemRangeByScore
|
||||||
routerMap["zremrangebyrank"] = ZRemRangeByRank
|
routerMap["zremrangebyrank"] = ZRemRangeByRank
|
||||||
|
|
||||||
routerMap["geoadd"] = GeoAdd
|
routerMap["geoadd"] = GeoAdd
|
||||||
routerMap["geopos"] = GeoPos
|
routerMap["geopos"] = GeoPos
|
||||||
routerMap["geodist"] = GeoDist
|
routerMap["geodist"] = GeoDist
|
||||||
routerMap["geohash"] = GeoHash
|
routerMap["geohash"] = GeoHash
|
||||||
routerMap["georadius"] = GeoRadius
|
routerMap["georadius"] = GeoRadius
|
||||||
routerMap["georadiusbymember"] = GeoRadiusByMember
|
routerMap["georadiusbymember"] = GeoRadiusByMember
|
||||||
|
|
||||||
routerMap["flushdb"] = FlushDB
|
routerMap["flushdb"] = FlushDB
|
||||||
routerMap["flushall"] = FlushAll
|
routerMap["flushall"] = FlushAll
|
||||||
routerMap["keys"] = Keys
|
routerMap["keys"] = Keys
|
||||||
|
|
||||||
return routerMap
|
return routerMap
|
||||||
}
|
}
|
||||||
|
@@ -1,16 +1,16 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/HDT3213/godis/src/interface/redis"
|
"github.com/HDT3213/godis/src/interface/redis"
|
||||||
"github.com/HDT3213/godis/src/redis/reply"
|
"github.com/HDT3213/godis/src/redis/reply"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Ping(db *DB, args [][]byte) redis.Reply {
|
func Ping(db *DB, args [][]byte) redis.Reply {
|
||||||
if len(args) == 0 {
|
if len(args) == 0 {
|
||||||
return &reply.PongReply{}
|
return &reply.PongReply{}
|
||||||
} else if len(args) == 1 {
|
} else if len(args) == 1 {
|
||||||
return reply.MakeStatusReply("\"" + string(args[0]) + "\"")
|
return reply.MakeStatusReply("\"" + string(args[0]) + "\"")
|
||||||
} else {
|
} else {
|
||||||
return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command")
|
return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
1010
src/db/sortedset.go
1010
src/db/sortedset.go
File diff suppressed because it is too large
Load Diff
@@ -3,7 +3,7 @@ package db
|
|||||||
import "github.com/HDT3213/godis/src/interface/redis"
|
import "github.com/HDT3213/godis/src/interface/redis"
|
||||||
|
|
||||||
type DB interface {
|
type DB interface {
|
||||||
Exec(client redis.Connection, args [][]byte) redis.Reply
|
Exec(client redis.Connection, args [][]byte) redis.Reply
|
||||||
AfterClientClose(c redis.Connection)
|
AfterClientClose(c redis.Connection)
|
||||||
Close()
|
Close()
|
||||||
}
|
}
|
||||||
|
@@ -1,11 +1,11 @@
|
|||||||
package redis
|
package redis
|
||||||
|
|
||||||
type Connection interface {
|
type Connection interface {
|
||||||
Write([]byte) error
|
Write([]byte) error
|
||||||
|
|
||||||
// client should keep its subscribing channels
|
// client should keep its subscribing channels
|
||||||
SubsChannel(channel string)
|
SubsChannel(channel string)
|
||||||
UnSubsChannel(channel string)
|
UnSubsChannel(channel string)
|
||||||
SubsCount()int
|
SubsCount() int
|
||||||
GetChannels()[]string
|
GetChannels() []string
|
||||||
}
|
}
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
package redis
|
package redis
|
||||||
|
|
||||||
type Reply interface {
|
type Reply interface {
|
||||||
ToBytes()[]byte
|
ToBytes() []byte
|
||||||
}
|
}
|
||||||
|
@@ -1,13 +1,13 @@
|
|||||||
package tcp
|
package tcp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"context"
|
||||||
"context"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type HandleFunc func(ctx context.Context, conn net.Conn)
|
type HandleFunc func(ctx context.Context, conn net.Conn)
|
||||||
|
|
||||||
type Handler interface {
|
type Handler interface {
|
||||||
Handle(ctx context.Context, conn net.Conn)
|
Handle(ctx context.Context, conn net.Conn)
|
||||||
Close()error
|
Close() error
|
||||||
}
|
}
|
||||||
|
@@ -1,80 +1,80 @@
|
|||||||
package consistenthash
|
package consistenthash
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"hash/crc32"
|
"hash/crc32"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type HashFunc func(data []byte) uint32
|
type HashFunc func(data []byte) uint32
|
||||||
|
|
||||||
type Map struct {
|
type Map struct {
|
||||||
hashFunc HashFunc
|
hashFunc HashFunc
|
||||||
replicas int
|
replicas int
|
||||||
keys []int // sorted
|
keys []int // sorted
|
||||||
hashMap map[int]string
|
hashMap map[int]string
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(replicas int, fn HashFunc) *Map {
|
func New(replicas int, fn HashFunc) *Map {
|
||||||
m := &Map{
|
m := &Map{
|
||||||
replicas: replicas,
|
replicas: replicas,
|
||||||
hashFunc: fn,
|
hashFunc: fn,
|
||||||
hashMap: make(map[int]string),
|
hashMap: make(map[int]string),
|
||||||
}
|
}
|
||||||
if m.hashFunc == nil {
|
if m.hashFunc == nil {
|
||||||
m.hashFunc = crc32.ChecksumIEEE
|
m.hashFunc = crc32.ChecksumIEEE
|
||||||
}
|
}
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map) IsEmpty() bool {
|
func (m *Map) IsEmpty() bool {
|
||||||
return len(m.keys) == 0
|
return len(m.keys) == 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map) Add(keys ...string) {
|
func (m *Map) Add(keys ...string) {
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
if key == "" {
|
if key == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for i := 0; i < m.replicas; i++ {
|
for i := 0; i < m.replicas; i++ {
|
||||||
hash := int(m.hashFunc([]byte(strconv.Itoa(i) + key)))
|
hash := int(m.hashFunc([]byte(strconv.Itoa(i) + key)))
|
||||||
m.keys = append(m.keys, hash)
|
m.keys = append(m.keys, hash)
|
||||||
m.hashMap[hash] = key
|
m.hashMap[hash] = key
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sort.Ints(m.keys)
|
sort.Ints(m.keys)
|
||||||
}
|
}
|
||||||
|
|
||||||
// support hash tag
|
// support hash tag
|
||||||
func getPartitionKey(key string) string {
|
func getPartitionKey(key string) string {
|
||||||
beg := strings.Index(key, "{")
|
beg := strings.Index(key, "{")
|
||||||
if beg == -1 {
|
if beg == -1 {
|
||||||
return key
|
return key
|
||||||
}
|
}
|
||||||
end := strings.Index(key, "}")
|
end := strings.Index(key, "}")
|
||||||
if end == -1 || end == beg+1 {
|
if end == -1 || end == beg+1 {
|
||||||
return key
|
return key
|
||||||
}
|
}
|
||||||
return key[beg+1 : end]
|
return key[beg+1 : end]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get gets the closest item in the hash to the provided key.
|
// Get gets the closest item in the hash to the provided key.
|
||||||
func (m *Map) Get(key string) string {
|
func (m *Map) Get(key string) string {
|
||||||
if m.IsEmpty() {
|
if m.IsEmpty() {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
partitionKey := getPartitionKey(key)
|
partitionKey := getPartitionKey(key)
|
||||||
hash := int(m.hashFunc([]byte(partitionKey)))
|
hash := int(m.hashFunc([]byte(partitionKey)))
|
||||||
|
|
||||||
// Binary search for appropriate replica.
|
// Binary search for appropriate replica.
|
||||||
idx := sort.Search(len(m.keys), func(i int) bool { return m.keys[i] >= hash })
|
idx := sort.Search(len(m.keys), func(i int) bool { return m.keys[i] >= hash })
|
||||||
|
|
||||||
// Means we have cycled back to the first replica.
|
// Means we have cycled back to the first replica.
|
||||||
if idx == len(m.keys) {
|
if idx == len(m.keys) {
|
||||||
idx = 0
|
idx = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.hashMap[m.keys[idx]]
|
return m.hashMap[m.keys[idx]]
|
||||||
}
|
}
|
||||||
|
@@ -1,78 +1,78 @@
|
|||||||
package files
|
package files
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"mime/multipart"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"path"
|
"mime/multipart"
|
||||||
"os"
|
"os"
|
||||||
"fmt"
|
"path"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetSize(f multipart.File) (int, error) {
|
func GetSize(f multipart.File) (int, error) {
|
||||||
content, err := ioutil.ReadAll(f)
|
content, err := ioutil.ReadAll(f)
|
||||||
|
|
||||||
return len(content), err
|
return len(content), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetExt(fileName string) string {
|
func GetExt(fileName string) string {
|
||||||
return path.Ext(fileName)
|
return path.Ext(fileName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckNotExist(src string) bool {
|
func CheckNotExist(src string) bool {
|
||||||
_, err := os.Stat(src)
|
_, err := os.Stat(src)
|
||||||
|
|
||||||
return os.IsNotExist(err)
|
return os.IsNotExist(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckPermission(src string) bool {
|
func CheckPermission(src string) bool {
|
||||||
_, err := os.Stat(src)
|
_, err := os.Stat(src)
|
||||||
|
|
||||||
return os.IsPermission(err)
|
return os.IsPermission(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsNotExistMkDir(src string) error {
|
func IsNotExistMkDir(src string) error {
|
||||||
if notExist := CheckNotExist(src); notExist == true {
|
if notExist := CheckNotExist(src); notExist == true {
|
||||||
if err := MkDir(src); err != nil {
|
if err := MkDir(src); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MkDir(src string) error {
|
func MkDir(src string) error {
|
||||||
err := os.MkdirAll(src, os.ModePerm)
|
err := os.MkdirAll(src, os.ModePerm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Open(name string, flag int, perm os.FileMode) (*os.File, error) {
|
func Open(name string, flag int, perm os.FileMode) (*os.File, error) {
|
||||||
f, err := os.OpenFile(name, flag, perm)
|
f, err := os.OpenFile(name, flag, perm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return f, nil
|
return f, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MustOpen(fileName, dir string) (*os.File, error) {
|
func MustOpen(fileName, dir string) (*os.File, error) {
|
||||||
perm := CheckPermission(dir)
|
perm := CheckPermission(dir)
|
||||||
if perm == true {
|
if perm == true {
|
||||||
return nil, fmt.Errorf("permission denied dir: %s", dir)
|
return nil, fmt.Errorf("permission denied dir: %s", dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := IsNotExistMkDir(dir)
|
err := IsNotExistMkDir(dir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error during make dir %s, err: %s", dir, err)
|
return nil, fmt.Errorf("error during make dir %s, err: %s", dir, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
f, err := Open(dir + string(os.PathSeparator) + fileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0644)
|
f, err := Open(dir+string(os.PathSeparator)+fileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("fail to open file, err: %s", err)
|
return nil, fmt.Errorf("fail to open file, err: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return f, nil
|
return f, nil
|
||||||
}
|
}
|
||||||
|
@@ -1,9 +1,9 @@
|
|||||||
package geohash
|
package geohash
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/base32"
|
"encoding/base32"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
)
|
)
|
||||||
|
|
||||||
var bits = []uint8{128, 64, 32, 16, 8, 4, 2, 1}
|
var bits = []uint8{128, 64, 32, 16, 8, 4, 2, 1}
|
||||||
@@ -13,95 +13,95 @@ const defaultBitSize = 64 // 32 bits for latitude, another 32 bits for longitude
|
|||||||
|
|
||||||
// return: geohash, box
|
// return: geohash, box
|
||||||
func encode0(latitude, longitude float64, bitSize uint) ([]byte, [2][2]float64) {
|
func encode0(latitude, longitude float64, bitSize uint) ([]byte, [2][2]float64) {
|
||||||
box := [2][2]float64{
|
box := [2][2]float64{
|
||||||
{-180, 180}, // lng
|
{-180, 180}, // lng
|
||||||
{-90, 90}, // lat
|
{-90, 90}, // lat
|
||||||
}
|
}
|
||||||
pos := [2]float64{longitude, latitude}
|
pos := [2]float64{longitude, latitude}
|
||||||
hash := &bytes.Buffer{}
|
hash := &bytes.Buffer{}
|
||||||
bit := 0
|
bit := 0
|
||||||
var precision uint = 0
|
var precision uint = 0
|
||||||
code := uint8(0)
|
code := uint8(0)
|
||||||
for precision < bitSize {
|
for precision < bitSize {
|
||||||
for direction, val := range pos {
|
for direction, val := range pos {
|
||||||
mid := (box[direction][0] + box[direction][1]) / 2
|
mid := (box[direction][0] + box[direction][1]) / 2
|
||||||
if val < mid {
|
if val < mid {
|
||||||
box[direction][1] = mid
|
box[direction][1] = mid
|
||||||
} else {
|
} else {
|
||||||
box[direction][0] = mid
|
box[direction][0] = mid
|
||||||
code |= bits[bit]
|
code |= bits[bit]
|
||||||
}
|
}
|
||||||
bit++
|
bit++
|
||||||
if bit == 8 {
|
if bit == 8 {
|
||||||
hash.WriteByte(code)
|
hash.WriteByte(code)
|
||||||
bit = 0
|
bit = 0
|
||||||
code = 0
|
code = 0
|
||||||
}
|
}
|
||||||
precision++
|
precision++
|
||||||
if precision == bitSize {
|
if precision == bitSize {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// precision%8 > 0
|
// precision%8 > 0
|
||||||
if code > 0 {
|
if code > 0 {
|
||||||
hash.WriteByte(code)
|
hash.WriteByte(code)
|
||||||
}
|
}
|
||||||
return hash.Bytes(), box
|
return hash.Bytes(), box
|
||||||
}
|
}
|
||||||
|
|
||||||
func Encode(latitude, longitude float64) uint64 {
|
func Encode(latitude, longitude float64) uint64 {
|
||||||
buf, _ := encode0(latitude, longitude, defaultBitSize)
|
buf, _ := encode0(latitude, longitude, defaultBitSize)
|
||||||
return binary.BigEndian.Uint64(buf)
|
return binary.BigEndian.Uint64(buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func decode0(hash []byte) [][]float64 {
|
func decode0(hash []byte) [][]float64 {
|
||||||
box := [][]float64{
|
box := [][]float64{
|
||||||
{-180, 180},
|
{-180, 180},
|
||||||
{-90, 90},
|
{-90, 90},
|
||||||
}
|
}
|
||||||
direction := 0
|
direction := 0
|
||||||
for i := 0; i < len(hash); i++ {
|
for i := 0; i < len(hash); i++ {
|
||||||
code := hash[i]
|
code := hash[i]
|
||||||
for j := 0; j < len(bits); j++ {
|
for j := 0; j < len(bits); j++ {
|
||||||
mid := (box[direction][0] + box[direction][1]) / 2
|
mid := (box[direction][0] + box[direction][1]) / 2
|
||||||
mask := bits[j]
|
mask := bits[j]
|
||||||
if mask&code > 0 {
|
if mask&code > 0 {
|
||||||
box[direction][0] = mid
|
box[direction][0] = mid
|
||||||
} else {
|
} else {
|
||||||
box[direction][1] = mid
|
box[direction][1] = mid
|
||||||
}
|
}
|
||||||
direction = (direction + 1) % 2
|
direction = (direction + 1) % 2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return box
|
return box
|
||||||
}
|
}
|
||||||
|
|
||||||
func Decode(code uint64) (float64, float64) {
|
func Decode(code uint64) (float64, float64) {
|
||||||
buf := make([]byte, 8)
|
buf := make([]byte, 8)
|
||||||
binary.BigEndian.PutUint64(buf, code)
|
binary.BigEndian.PutUint64(buf, code)
|
||||||
box := decode0(buf)
|
box := decode0(buf)
|
||||||
lng := float64(box[0][0]+box[0][1]) / 2
|
lng := float64(box[0][0]+box[0][1]) / 2
|
||||||
lat := float64(box[1][0]+box[1][1]) / 2
|
lat := float64(box[1][0]+box[1][1]) / 2
|
||||||
return lat, lng
|
return lat, lng
|
||||||
}
|
}
|
||||||
|
|
||||||
func ToString(buf []byte) string {
|
func ToString(buf []byte) string {
|
||||||
return enc.EncodeToString(buf)
|
return enc.EncodeToString(buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ToInt(buf []byte) uint64 {
|
func ToInt(buf []byte) uint64 {
|
||||||
// padding
|
// padding
|
||||||
if len(buf) < 8 {
|
if len(buf) < 8 {
|
||||||
buf2 := make([]byte, 8)
|
buf2 := make([]byte, 8)
|
||||||
copy(buf2, buf)
|
copy(buf2, buf)
|
||||||
return binary.BigEndian.Uint64(buf2)
|
return binary.BigEndian.Uint64(buf2)
|
||||||
}
|
}
|
||||||
return binary.BigEndian.Uint64(buf)
|
return binary.BigEndian.Uint64(buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
func FromInt(code uint64) []byte {
|
func FromInt(code uint64) []byte {
|
||||||
buf := make([]byte, 8)
|
buf := make([]byte, 8)
|
||||||
binary.BigEndian.PutUint64(buf, code)
|
binary.BigEndian.PutUint64(buf, code)
|
||||||
return buf
|
return buf
|
||||||
}
|
}
|
||||||
|
@@ -1,39 +1,39 @@
|
|||||||
package geohash
|
package geohash
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestToRange(t *testing.T) {
|
func TestToRange(t *testing.T) {
|
||||||
neighbor := []byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00}
|
neighbor := []byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00}
|
||||||
range_ := ToRange(neighbor, 36)
|
range_ := ToRange(neighbor, 36)
|
||||||
expectedLower := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00})
|
expectedLower := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00})
|
||||||
expectedUpper := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xF0, 0x00, 0x00, 0x00})
|
expectedUpper := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xF0, 0x00, 0x00, 0x00})
|
||||||
if expectedLower != range_[0] {
|
if expectedLower != range_[0] {
|
||||||
t.Error("incorrect lower")
|
t.Error("incorrect lower")
|
||||||
}
|
}
|
||||||
if expectedUpper != range_[1] {
|
if expectedUpper != range_[1] {
|
||||||
t.Error("incorrect upper")
|
t.Error("incorrect upper")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEncode(t *testing.T) {
|
func TestEncode(t *testing.T) {
|
||||||
lat0 := 48.669
|
lat0 := 48.669
|
||||||
lng0 := -4.32913
|
lng0 := -4.32913
|
||||||
hash := Encode(lat0, lng0)
|
hash := Encode(lat0, lng0)
|
||||||
str := ToString(FromInt(hash))
|
str := ToString(FromInt(hash))
|
||||||
if str != "gbsuv7zt7zntw" {
|
if str != "gbsuv7zt7zntw" {
|
||||||
t.Error("encode error")
|
t.Error("encode error")
|
||||||
}
|
}
|
||||||
lat, lng := Decode(hash)
|
lat, lng := Decode(hash)
|
||||||
if math.Abs(lat-lat0) > 1e-6 || math.Abs(lng-lng0) > 1e-6 {
|
if math.Abs(lat-lat0) > 1e-6 || math.Abs(lng-lng0) > 1e-6 {
|
||||||
t.Error("decode error")
|
t.Error("decode error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetNeighbours(t *testing.T) {
|
func TestGetNeighbours(t *testing.T) {
|
||||||
ranges := GetNeighbours(90, 180, 630*1000)
|
ranges := GetNeighbours(90, 180, 630*1000)
|
||||||
fmt.Printf("%#v", ranges)
|
fmt.Printf("%#v", ranges)
|
||||||
}
|
}
|
||||||
|
@@ -3,134 +3,134 @@ package geohash
|
|||||||
import "math"
|
import "math"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
DR = math.Pi / 180.0
|
DR = math.Pi / 180.0
|
||||||
EarthRadius = 6372797.560856
|
EarthRadius = 6372797.560856
|
||||||
MercatorMax = 20037726.37 // pi * EarthRadius
|
MercatorMax = 20037726.37 // pi * EarthRadius
|
||||||
MercatorMin = -20037726.37
|
MercatorMin = -20037726.37
|
||||||
)
|
)
|
||||||
|
|
||||||
func degRad(ang float64) float64 {
|
func degRad(ang float64) float64 {
|
||||||
return ang * DR
|
return ang * DR
|
||||||
}
|
}
|
||||||
|
|
||||||
func radDeg(ang float64) float64 {
|
func radDeg(ang float64) float64 {
|
||||||
return ang / DR
|
return ang / DR
|
||||||
}
|
}
|
||||||
|
|
||||||
func getBoundingBox(latitude float64, longitude float64, radiusMeters float64) (
|
func getBoundingBox(latitude float64, longitude float64, radiusMeters float64) (
|
||||||
minLat, maxLat, minLng, maxLng float64) {
|
minLat, maxLat, minLng, maxLng float64) {
|
||||||
minLng = longitude - radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude)))
|
minLng = longitude - radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude)))
|
||||||
if minLng < -180 {
|
if minLng < -180 {
|
||||||
minLng = -180
|
minLng = -180
|
||||||
}
|
}
|
||||||
maxLng = longitude + radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude)))
|
maxLng = longitude + radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude)))
|
||||||
if maxLng > 180 {
|
if maxLng > 180 {
|
||||||
maxLng = 180
|
maxLng = 180
|
||||||
}
|
}
|
||||||
minLat = latitude - radDeg(radiusMeters/EarthRadius)
|
minLat = latitude - radDeg(radiusMeters/EarthRadius)
|
||||||
if minLat < -90 {
|
if minLat < -90 {
|
||||||
minLat = -90
|
minLat = -90
|
||||||
}
|
}
|
||||||
maxLat = latitude + radDeg(radiusMeters/EarthRadius)
|
maxLat = latitude + radDeg(radiusMeters/EarthRadius)
|
||||||
if maxLat > 90 {
|
if maxLat > 90 {
|
||||||
maxLat = 90
|
maxLat = 90
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func estimatePrecisionByRadius(radiusMeters float64, latitude float64) uint {
|
func estimatePrecisionByRadius(radiusMeters float64, latitude float64) uint {
|
||||||
if radiusMeters == 0 {
|
if radiusMeters == 0 {
|
||||||
return defaultBitSize - 1
|
return defaultBitSize - 1
|
||||||
}
|
}
|
||||||
var precision uint = 1
|
var precision uint = 1
|
||||||
for radiusMeters < MercatorMax {
|
for radiusMeters < MercatorMax {
|
||||||
radiusMeters *= 2
|
radiusMeters *= 2
|
||||||
precision++
|
precision++
|
||||||
}
|
}
|
||||||
/* Make sure range is included in most of the base cases. */
|
/* Make sure range is included in most of the base cases. */
|
||||||
precision -= 2
|
precision -= 2
|
||||||
if latitude > 66 || latitude < -66 {
|
if latitude > 66 || latitude < -66 {
|
||||||
precision--
|
precision--
|
||||||
if latitude > 80 || latitude < -80 {
|
if latitude > 80 || latitude < -80 {
|
||||||
precision--
|
precision--
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if precision < 1 {
|
if precision < 1 {
|
||||||
precision = 1
|
precision = 1
|
||||||
}
|
}
|
||||||
if precision > 32 {
|
if precision > 32 {
|
||||||
precision = 32
|
precision = 32
|
||||||
}
|
}
|
||||||
return precision*2 - 1
|
return precision*2 - 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func Distance(latitude1, longitude1, latitude2, longitude2 float64) float64 {
|
func Distance(latitude1, longitude1, latitude2, longitude2 float64) float64 {
|
||||||
radLat1 := degRad(latitude1)
|
radLat1 := degRad(latitude1)
|
||||||
radLat2 := degRad(latitude2)
|
radLat2 := degRad(latitude2)
|
||||||
a := radLat1 - radLat2
|
a := radLat1 - radLat2
|
||||||
b := degRad(longitude1) - degRad(longitude2)
|
b := degRad(longitude1) - degRad(longitude2)
|
||||||
return 2 * EarthRadius * math.Asin(math.Sqrt(math.Pow(math.Sin(a/2), 2) +
|
return 2 * EarthRadius * math.Asin(math.Sqrt(math.Pow(math.Sin(a/2), 2)+
|
||||||
math.Cos(radLat1)*math.Cos(radLat2)*math.Pow(math.Sin(b/2), 2)))
|
math.Cos(radLat1)*math.Cos(radLat2)*math.Pow(math.Sin(b/2), 2)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func ToRange(scope []byte, precision uint) [2]uint64 {
|
func ToRange(scope []byte, precision uint) [2]uint64 {
|
||||||
lower := ToInt(scope)
|
lower := ToInt(scope)
|
||||||
radius := uint64(1 << (64 - precision))
|
radius := uint64(1 << (64 - precision))
|
||||||
upper := lower + radius
|
upper := lower + radius
|
||||||
return [2]uint64{lower, upper}
|
return [2]uint64{lower, upper}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ensureValidLat(lat float64) float64 {
|
func ensureValidLat(lat float64) float64 {
|
||||||
if lat > 90 {
|
if lat > 90 {
|
||||||
return 90
|
return 90
|
||||||
}
|
}
|
||||||
if lat < -90 {
|
if lat < -90 {
|
||||||
return -90
|
return -90
|
||||||
}
|
}
|
||||||
return lat
|
return lat
|
||||||
}
|
}
|
||||||
|
|
||||||
func ensureValidLng(lng float64) float64 {
|
func ensureValidLng(lng float64) float64 {
|
||||||
if lng > 180 {
|
if lng > 180 {
|
||||||
return -360 + lng
|
return -360 + lng
|
||||||
}
|
}
|
||||||
if lng < -180 {
|
if lng < -180 {
|
||||||
return 360 + lng
|
return 360 + lng
|
||||||
}
|
}
|
||||||
return lng
|
return lng
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetNeighbours(latitude, longitude, radiusMeters float64) [][2]uint64 {
|
func GetNeighbours(latitude, longitude, radiusMeters float64) [][2]uint64 {
|
||||||
precision := estimatePrecisionByRadius(radiusMeters, latitude)
|
precision := estimatePrecisionByRadius(radiusMeters, latitude)
|
||||||
|
|
||||||
center, box := encode0(latitude, longitude, precision)
|
center, box := encode0(latitude, longitude, precision)
|
||||||
height := box[0][1] - box[0][0]
|
height := box[0][1] - box[0][0]
|
||||||
width := box[1][1] - box[1][0]
|
width := box[1][1] - box[1][0]
|
||||||
centerLng := (box[0][1] + box[0][0]) / 2
|
centerLng := (box[0][1] + box[0][0]) / 2
|
||||||
centerLat := (box[1][1] + box[1][0]) / 2
|
centerLat := (box[1][1] + box[1][0]) / 2
|
||||||
maxLat := ensureValidLat(centerLat + height)
|
maxLat := ensureValidLat(centerLat + height)
|
||||||
minLat := ensureValidLat(centerLat - height)
|
minLat := ensureValidLat(centerLat - height)
|
||||||
maxLng := ensureValidLng(centerLng + width)
|
maxLng := ensureValidLng(centerLng + width)
|
||||||
minLng := ensureValidLng(centerLng - width)
|
minLng := ensureValidLng(centerLng - width)
|
||||||
|
|
||||||
var result [10][2]uint64
|
var result [10][2]uint64
|
||||||
leftUpper, _ := encode0(maxLat, minLng, precision)
|
leftUpper, _ := encode0(maxLat, minLng, precision)
|
||||||
result[1] = ToRange(leftUpper, precision)
|
result[1] = ToRange(leftUpper, precision)
|
||||||
upper, _ := encode0(maxLat, centerLng, precision)
|
upper, _ := encode0(maxLat, centerLng, precision)
|
||||||
result[2] = ToRange(upper, precision)
|
result[2] = ToRange(upper, precision)
|
||||||
rightUpper, _ := encode0(maxLat, maxLng, precision)
|
rightUpper, _ := encode0(maxLat, maxLng, precision)
|
||||||
result[3] = ToRange(rightUpper, precision)
|
result[3] = ToRange(rightUpper, precision)
|
||||||
left, _ := encode0(centerLat, minLng, precision)
|
left, _ := encode0(centerLat, minLng, precision)
|
||||||
result[4] = ToRange(left, precision)
|
result[4] = ToRange(left, precision)
|
||||||
result[5] = ToRange(center, precision)
|
result[5] = ToRange(center, precision)
|
||||||
right, _ := encode0(centerLat, maxLng, precision)
|
right, _ := encode0(centerLat, maxLng, precision)
|
||||||
result[6] = ToRange(right, precision)
|
result[6] = ToRange(right, precision)
|
||||||
leftDown, _ := encode0(minLat, minLng, precision)
|
leftDown, _ := encode0(minLat, minLng, precision)
|
||||||
result[7] = ToRange(leftDown, precision)
|
result[7] = ToRange(leftDown, precision)
|
||||||
down, _ := encode0(minLat, centerLng, precision)
|
down, _ := encode0(minLat, centerLng, precision)
|
||||||
result[8] = ToRange(down, precision)
|
result[8] = ToRange(down, precision)
|
||||||
rightDown, _ := encode0(minLat, maxLng, precision)
|
rightDown, _ := encode0(minLat, maxLng, precision)
|
||||||
result[9] = ToRange(rightDown, precision)
|
result[9] = ToRange(rightDown, precision)
|
||||||
|
|
||||||
return result[1:]
|
return result[1:]
|
||||||
}
|
}
|
||||||
|
@@ -1,21 +1,21 @@
|
|||||||
package gob
|
package gob
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Marshal(obj interface{}) ([]byte, error) {
|
func Marshal(obj interface{}) ([]byte, error) {
|
||||||
buf := new(bytes.Buffer)
|
buf := new(bytes.Buffer)
|
||||||
enc := gob.NewEncoder(buf)
|
enc := gob.NewEncoder(buf)
|
||||||
err := enc.Encode(obj)
|
err := enc.Encode(obj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return buf.Bytes(), nil
|
return buf.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnMarshal(data []byte, obj interface{}) error {
|
func UnMarshal(data []byte, obj interface{}) error {
|
||||||
dec := gob.NewDecoder(bytes.NewBuffer(data))
|
dec := gob.NewDecoder(bytes.NewBuffer(data))
|
||||||
return dec.Decode(obj)
|
return dec.Decode(obj)
|
||||||
}
|
}
|
||||||
|
@@ -4,16 +4,14 @@ import "sync/atomic"
|
|||||||
|
|
||||||
type AtomicBool uint32
|
type AtomicBool uint32
|
||||||
|
|
||||||
func (b *AtomicBool)Get()bool {
|
func (b *AtomicBool) Get() bool {
|
||||||
return atomic.LoadUint32((*uint32)(b)) != 0
|
return atomic.LoadUint32((*uint32)(b)) != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *AtomicBool)Set(v bool) {
|
func (b *AtomicBool) Set(v bool) {
|
||||||
if v {
|
if v {
|
||||||
atomic.StoreUint32((*uint32)(b), 1)
|
atomic.StoreUint32((*uint32)(b), 1)
|
||||||
} else {
|
} else {
|
||||||
atomic.StoreUint32((*uint32)(b), 0)
|
atomic.StoreUint32((*uint32)(b), 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@@ -1,38 +1,38 @@
|
|||||||
package wait
|
package wait
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Wait struct {
|
type Wait struct {
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Wait)Add(delta int) {
|
func (w *Wait) Add(delta int) {
|
||||||
w.wg.Add(delta)
|
w.wg.Add(delta)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Wait)Done() {
|
func (w *Wait) Done() {
|
||||||
w.wg.Done()
|
w.wg.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *Wait)Wait() {
|
func (w *Wait) Wait() {
|
||||||
w.wg.Wait()
|
w.wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// return isTimeout
|
// return isTimeout
|
||||||
func (w *Wait)WaitWithTimeout(timeout time.Duration)bool {
|
func (w *Wait) WaitWithTimeout(timeout time.Duration) bool {
|
||||||
c := make(chan bool)
|
c := make(chan bool)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(c)
|
defer close(c)
|
||||||
w.wg.Wait()
|
w.wg.Wait()
|
||||||
c <- true
|
c <- true
|
||||||
}()
|
}()
|
||||||
select {
|
select {
|
||||||
case <-c:
|
case <-c:
|
||||||
return false // completed normally
|
return false // completed normally
|
||||||
case <-time.After(timeout):
|
case <-time.After(timeout):
|
||||||
return true // timed out
|
return true // timed out
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -1,95 +1,95 @@
|
|||||||
package wildcard
|
package wildcard
|
||||||
|
|
||||||
const (
|
const (
|
||||||
normal = iota
|
normal = iota
|
||||||
all // *
|
all // *
|
||||||
any // ?
|
any // ?
|
||||||
set_ // []
|
set_ // []
|
||||||
)
|
)
|
||||||
|
|
||||||
type item struct {
|
type item struct {
|
||||||
character byte
|
character byte
|
||||||
set map[byte]bool
|
set map[byte]bool
|
||||||
typeCode int
|
typeCode int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *item) contains(c byte) bool {
|
func (i *item) contains(c byte) bool {
|
||||||
_, ok := i.set[c]
|
_, ok := i.set[c]
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
|
|
||||||
type Pattern struct {
|
type Pattern struct {
|
||||||
items []*item
|
items []*item
|
||||||
}
|
}
|
||||||
|
|
||||||
func CompilePattern(src string) *Pattern {
|
func CompilePattern(src string) *Pattern {
|
||||||
items := make([]*item, 0)
|
items := make([]*item, 0)
|
||||||
escape := false
|
escape := false
|
||||||
inSet := false
|
inSet := false
|
||||||
var set map[byte]bool
|
var set map[byte]bool
|
||||||
for _, v := range src {
|
for _, v := range src {
|
||||||
c := byte(v)
|
c := byte(v)
|
||||||
if escape {
|
if escape {
|
||||||
items = append(items, &item{typeCode: normal, character: c})
|
items = append(items, &item{typeCode: normal, character: c})
|
||||||
escape = false
|
escape = false
|
||||||
} else if c == '*' {
|
} else if c == '*' {
|
||||||
items = append(items, &item{typeCode: all})
|
items = append(items, &item{typeCode: all})
|
||||||
} else if c == '?' {
|
} else if c == '?' {
|
||||||
items = append(items, &item{typeCode: any})
|
items = append(items, &item{typeCode: any})
|
||||||
} else if c == '\\' {
|
} else if c == '\\' {
|
||||||
escape = true
|
escape = true
|
||||||
} else if c == '[' {
|
} else if c == '[' {
|
||||||
if !inSet {
|
if !inSet {
|
||||||
inSet = true
|
inSet = true
|
||||||
set = make(map[byte]bool)
|
set = make(map[byte]bool)
|
||||||
} else {
|
} else {
|
||||||
set[c] = true
|
set[c] = true
|
||||||
}
|
}
|
||||||
} else if c == ']' {
|
} else if c == ']' {
|
||||||
if inSet {
|
if inSet {
|
||||||
inSet = false
|
inSet = false
|
||||||
items = append(items, &item{typeCode: set_, set: set})
|
items = append(items, &item{typeCode: set_, set: set})
|
||||||
} else {
|
} else {
|
||||||
items = append(items, &item{typeCode: normal, character: c})
|
items = append(items, &item{typeCode: normal, character: c})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if inSet {
|
if inSet {
|
||||||
set[c] = true
|
set[c] = true
|
||||||
} else {
|
} else {
|
||||||
items = append(items, &item{typeCode: normal, character: c})
|
items = append(items, &item{typeCode: normal, character: c})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &Pattern{
|
return &Pattern{
|
||||||
items: items,
|
items: items,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Pattern) IsMatch(s string) bool {
|
func (p *Pattern) IsMatch(s string) bool {
|
||||||
if len(p.items) == 0 {
|
if len(p.items) == 0 {
|
||||||
return len(s) == 0
|
return len(s) == 0
|
||||||
}
|
}
|
||||||
m := len(s)
|
m := len(s)
|
||||||
n := len(p.items)
|
n := len(p.items)
|
||||||
table := make([][]bool, m+1)
|
table := make([][]bool, m+1)
|
||||||
for i := 0; i < m+1; i++ {
|
for i := 0; i < m+1; i++ {
|
||||||
table[i] = make([]bool, n+1)
|
table[i] = make([]bool, n+1)
|
||||||
}
|
}
|
||||||
table[0][0] = true
|
table[0][0] = true
|
||||||
for j := 1; j < n+1; j++ {
|
for j := 1; j < n+1; j++ {
|
||||||
table[0][j] = table[0][j-1] && p.items[j-1].typeCode == all
|
table[0][j] = table[0][j-1] && p.items[j-1].typeCode == all
|
||||||
}
|
}
|
||||||
for i := 1; i < m+1; i++ {
|
for i := 1; i < m+1; i++ {
|
||||||
for j := 1; j < n+1; j++ {
|
for j := 1; j < n+1; j++ {
|
||||||
if p.items[j-1].typeCode == all {
|
if p.items[j-1].typeCode == all {
|
||||||
table[i][j] = table[i-1][j] || table[i][j-1]
|
table[i][j] = table[i-1][j] || table[i][j-1]
|
||||||
} else {
|
} else {
|
||||||
table[i][j] = table[i-1][j-1] &&
|
table[i][j] = table[i-1][j-1] &&
|
||||||
(p.items[j-1].typeCode == any ||
|
(p.items[j-1].typeCode == any ||
|
||||||
(p.items[j-1].typeCode == normal && uint8(s[i-1]) == p.items[j-1].character) ||
|
(p.items[j-1].typeCode == normal && uint8(s[i-1]) == p.items[j-1].character) ||
|
||||||
(p.items[j-1].typeCode == set_ && p.items[j-1].contains(s[i-1])))
|
(p.items[j-1].typeCode == set_ && p.items[j-1].contains(s[i-1])))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return table[m][n]
|
return table[m][n]
|
||||||
}
|
}
|
||||||
|
@@ -3,73 +3,73 @@ package wildcard
|
|||||||
import "testing"
|
import "testing"
|
||||||
|
|
||||||
func TestWildCard(t *testing.T) {
|
func TestWildCard(t *testing.T) {
|
||||||
p := CompilePattern("a")
|
p := CompilePattern("a")
|
||||||
if !p.IsMatch("a") {
|
if !p.IsMatch("a") {
|
||||||
t.Error("expect true actually false")
|
t.Error("expect true actually false")
|
||||||
}
|
}
|
||||||
if p.IsMatch("b") {
|
if p.IsMatch("b") {
|
||||||
t.Error("expect false actually true")
|
t.Error("expect false actually true")
|
||||||
}
|
}
|
||||||
|
|
||||||
// test '?'
|
// test '?'
|
||||||
p = CompilePattern("a?")
|
p = CompilePattern("a?")
|
||||||
if !p.IsMatch("ab") {
|
if !p.IsMatch("ab") {
|
||||||
t.Error("expect true actually false")
|
t.Error("expect true actually false")
|
||||||
}
|
}
|
||||||
if p.IsMatch("a") {
|
if p.IsMatch("a") {
|
||||||
t.Error("expect false actually true")
|
t.Error("expect false actually true")
|
||||||
}
|
}
|
||||||
if p.IsMatch("abb") {
|
if p.IsMatch("abb") {
|
||||||
t.Error("expect false actually true")
|
t.Error("expect false actually true")
|
||||||
}
|
}
|
||||||
if p.IsMatch("bb") {
|
if p.IsMatch("bb") {
|
||||||
t.Error("expect false actually true")
|
t.Error("expect false actually true")
|
||||||
}
|
}
|
||||||
|
|
||||||
// test *
|
// test *
|
||||||
p = CompilePattern("a*")
|
p = CompilePattern("a*")
|
||||||
if !p.IsMatch("ab") {
|
if !p.IsMatch("ab") {
|
||||||
t.Error("expect true actually false")
|
t.Error("expect true actually false")
|
||||||
}
|
}
|
||||||
if !p.IsMatch("a") {
|
if !p.IsMatch("a") {
|
||||||
t.Error("expect true actually false")
|
t.Error("expect true actually false")
|
||||||
}
|
}
|
||||||
if !p.IsMatch("abb") {
|
if !p.IsMatch("abb") {
|
||||||
t.Error("expect true actually false")
|
t.Error("expect true actually false")
|
||||||
}
|
}
|
||||||
if p.IsMatch("bb") {
|
if p.IsMatch("bb") {
|
||||||
t.Error("expect false actually true")
|
t.Error("expect false actually true")
|
||||||
}
|
}
|
||||||
|
|
||||||
// test []
|
// test []
|
||||||
p = CompilePattern("a[ab[]")
|
p = CompilePattern("a[ab[]")
|
||||||
if !p.IsMatch("ab") {
|
if !p.IsMatch("ab") {
|
||||||
t.Error("expect true actually false")
|
t.Error("expect true actually false")
|
||||||
}
|
}
|
||||||
if !p.IsMatch("aa") {
|
if !p.IsMatch("aa") {
|
||||||
t.Error("expect true actually false")
|
t.Error("expect true actually false")
|
||||||
}
|
}
|
||||||
if !p.IsMatch("a[") {
|
if !p.IsMatch("a[") {
|
||||||
t.Error("expect true actually false")
|
t.Error("expect true actually false")
|
||||||
}
|
}
|
||||||
if p.IsMatch("abb") {
|
if p.IsMatch("abb") {
|
||||||
t.Error("expect false actually true")
|
t.Error("expect false actually true")
|
||||||
}
|
}
|
||||||
if p.IsMatch("bb") {
|
if p.IsMatch("bb") {
|
||||||
t.Error("expect false actually true")
|
t.Error("expect false actually true")
|
||||||
}
|
}
|
||||||
|
|
||||||
// test escape
|
// test escape
|
||||||
p = CompilePattern("\\\\") // pattern: \\
|
p = CompilePattern("\\\\") // pattern: \\
|
||||||
if !p.IsMatch("\\") {
|
if !p.IsMatch("\\") {
|
||||||
t.Error("expect true actually false")
|
t.Error("expect true actually false")
|
||||||
}
|
}
|
||||||
|
|
||||||
p = CompilePattern("\\*")
|
p = CompilePattern("\\*")
|
||||||
if !p.IsMatch("*") {
|
if !p.IsMatch("*") {
|
||||||
t.Error("expect true actually false")
|
t.Error("expect true actually false")
|
||||||
}
|
}
|
||||||
if p.IsMatch("a") {
|
if p.IsMatch("a") {
|
||||||
t.Error("expect false actually true")
|
t.Error("expect false actually true")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -1,20 +1,20 @@
|
|||||||
package pubsub
|
package pubsub
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/HDT3213/godis/src/datastruct/dict"
|
"github.com/HDT3213/godis/src/datastruct/dict"
|
||||||
"github.com/HDT3213/godis/src/datastruct/lock"
|
"github.com/HDT3213/godis/src/datastruct/lock"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Hub struct {
|
type Hub struct {
|
||||||
// channel -> list(*Client)
|
// channel -> list(*Client)
|
||||||
subs dict.Dict
|
subs dict.Dict
|
||||||
// lock channel
|
// lock channel
|
||||||
subsLocker *lock.Locks
|
subsLocker *lock.Locks
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeHub() *Hub {
|
func MakeHub() *Hub {
|
||||||
return &Hub{
|
return &Hub{
|
||||||
subs: dict.MakeConcurrent(4),
|
subs: dict.MakeConcurrent(4),
|
||||||
subsLocker: lock.Make(16),
|
subsLocker: lock.Make(16),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -1,23 +1,23 @@
|
|||||||
package pubsub
|
package pubsub
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/HDT3213/godis/src/datastruct/list"
|
"github.com/HDT3213/godis/src/datastruct/list"
|
||||||
"github.com/HDT3213/godis/src/interface/redis"
|
"github.com/HDT3213/godis/src/interface/redis"
|
||||||
"github.com/HDT3213/godis/src/redis/reply"
|
"github.com/HDT3213/godis/src/redis/reply"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
_subscribe = "subscribe"
|
_subscribe = "subscribe"
|
||||||
_unsubscribe = "unsubscribe"
|
_unsubscribe = "unsubscribe"
|
||||||
messageBytes = []byte("message")
|
messageBytes = []byte("message")
|
||||||
unSubscribeNothing = []byte("*3\r\n$11\r\nunsubscribe\r\n$-1\n:0\r\n")
|
unSubscribeNothing = []byte("*3\r\n$11\r\nunsubscribe\r\n$-1\n:0\r\n")
|
||||||
)
|
)
|
||||||
|
|
||||||
func makeMsg(t string, channel string, code int64) []byte {
|
func makeMsg(t string, channel string, code int64) []byte {
|
||||||
return []byte("*3\r\n$" + strconv.FormatInt(int64(len(t)), 10) + reply.CRLF + t + reply.CRLF +
|
return []byte("*3\r\n$" + strconv.FormatInt(int64(len(t)), 10) + reply.CRLF + t + reply.CRLF +
|
||||||
"$" + strconv.FormatInt(int64(len(channel)), 10) + reply.CRLF + channel + reply.CRLF +
|
"$" + strconv.FormatInt(int64(len(channel)), 10) + reply.CRLF + channel + reply.CRLF +
|
||||||
":" + strconv.FormatInt(code, 10) + reply.CRLF)
|
":" + strconv.FormatInt(code, 10) + reply.CRLF)
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@@ -25,22 +25,22 @@ func makeMsg(t string, channel string, code int64) []byte {
|
|||||||
* return: is new subscribed
|
* return: is new subscribed
|
||||||
*/
|
*/
|
||||||
func subscribe0(hub *Hub, channel string, client redis.Connection) bool {
|
func subscribe0(hub *Hub, channel string, client redis.Connection) bool {
|
||||||
client.SubsChannel(channel)
|
client.SubsChannel(channel)
|
||||||
|
|
||||||
// add into hub.subs
|
// add into hub.subs
|
||||||
raw, ok := hub.subs.Get(channel)
|
raw, ok := hub.subs.Get(channel)
|
||||||
var subscribers *list.LinkedList
|
var subscribers *list.LinkedList
|
||||||
if ok {
|
if ok {
|
||||||
subscribers, _ = raw.(*list.LinkedList)
|
subscribers, _ = raw.(*list.LinkedList)
|
||||||
} else {
|
} else {
|
||||||
subscribers = list.Make()
|
subscribers = list.Make()
|
||||||
hub.subs.Put(channel, subscribers)
|
hub.subs.Put(channel, subscribers)
|
||||||
}
|
}
|
||||||
if subscribers.Contains(client) {
|
if subscribers.Contains(client) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
subscribers.Add(client)
|
subscribers.Add(client)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
@@ -48,102 +48,102 @@ func subscribe0(hub *Hub, channel string, client redis.Connection) bool {
|
|||||||
* return: is actually un-subscribe
|
* return: is actually un-subscribe
|
||||||
*/
|
*/
|
||||||
func unsubscribe0(hub *Hub, channel string, client redis.Connection) bool {
|
func unsubscribe0(hub *Hub, channel string, client redis.Connection) bool {
|
||||||
client.UnSubsChannel(channel)
|
client.UnSubsChannel(channel)
|
||||||
|
|
||||||
// remove from hub.subs
|
// remove from hub.subs
|
||||||
raw, ok := hub.subs.Get(channel)
|
raw, ok := hub.subs.Get(channel)
|
||||||
if ok {
|
if ok {
|
||||||
subscribers, _ := raw.(*list.LinkedList)
|
subscribers, _ := raw.(*list.LinkedList)
|
||||||
subscribers.RemoveAllByVal(client)
|
subscribers.RemoveAllByVal(client)
|
||||||
|
|
||||||
if subscribers.Len() == 0 {
|
if subscribers.Len() == 0 {
|
||||||
// clean
|
// clean
|
||||||
hub.subs.Remove(channel)
|
hub.subs.Remove(channel)
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func Subscribe(hub *Hub, c redis.Connection, args [][]byte) redis.Reply {
|
func Subscribe(hub *Hub, c redis.Connection, args [][]byte) redis.Reply {
|
||||||
channels := make([]string, len(args))
|
channels := make([]string, len(args))
|
||||||
for i, b := range args {
|
for i, b := range args {
|
||||||
channels[i] = string(b)
|
channels[i] = string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
hub.subsLocker.Locks(channels...)
|
hub.subsLocker.Locks(channels...)
|
||||||
defer hub.subsLocker.UnLocks(channels...)
|
defer hub.subsLocker.UnLocks(channels...)
|
||||||
|
|
||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
if subscribe0(hub, channel, c) {
|
if subscribe0(hub, channel, c) {
|
||||||
_ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount())))
|
_ = c.Write(makeMsg(_subscribe, channel, int64(c.SubsCount())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &reply.NoReply{}
|
return &reply.NoReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnsubscribeAll(hub *Hub, c redis.Connection) {
|
func UnsubscribeAll(hub *Hub, c redis.Connection) {
|
||||||
channels := c.GetChannels()
|
channels := c.GetChannels()
|
||||||
|
|
||||||
hub.subsLocker.Locks(channels...)
|
hub.subsLocker.Locks(channels...)
|
||||||
defer hub.subsLocker.UnLocks(channels...)
|
defer hub.subsLocker.UnLocks(channels...)
|
||||||
|
|
||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
unsubscribe0(hub, channel, c)
|
unsubscribe0(hub, channel, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func UnSubscribe(db *Hub, c redis.Connection, args [][]byte) redis.Reply {
|
func UnSubscribe(db *Hub, c redis.Connection, args [][]byte) redis.Reply {
|
||||||
var channels []string
|
var channels []string
|
||||||
if len(args) > 0 {
|
if len(args) > 0 {
|
||||||
channels = make([]string, len(args))
|
channels = make([]string, len(args))
|
||||||
for i, b := range args {
|
for i, b := range args {
|
||||||
channels[i] = string(b)
|
channels[i] = string(b)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
channels = c.GetChannels()
|
channels = c.GetChannels()
|
||||||
}
|
}
|
||||||
|
|
||||||
db.subsLocker.Locks(channels...)
|
db.subsLocker.Locks(channels...)
|
||||||
defer db.subsLocker.UnLocks(channels...)
|
defer db.subsLocker.UnLocks(channels...)
|
||||||
|
|
||||||
if len(channels) == 0 {
|
if len(channels) == 0 {
|
||||||
_ = c.Write(unSubscribeNothing)
|
_ = c.Write(unSubscribeNothing)
|
||||||
return &reply.NoReply{}
|
return &reply.NoReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
if unsubscribe0(db, channel, c) {
|
if unsubscribe0(db, channel, c) {
|
||||||
_ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount())))
|
_ = c.Write(makeMsg(_unsubscribe, channel, int64(c.SubsCount())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &reply.NoReply{}
|
return &reply.NoReply{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Publish(hub *Hub, args [][]byte) redis.Reply {
|
func Publish(hub *Hub, args [][]byte) redis.Reply {
|
||||||
if len(args) != 2 {
|
if len(args) != 2 {
|
||||||
return &reply.ArgNumErrReply{Cmd: "publish"}
|
return &reply.ArgNumErrReply{Cmd: "publish"}
|
||||||
}
|
}
|
||||||
channel := string(args[0])
|
channel := string(args[0])
|
||||||
message := args[1]
|
message := args[1]
|
||||||
|
|
||||||
hub.subsLocker.Lock(channel)
|
hub.subsLocker.Lock(channel)
|
||||||
defer hub.subsLocker.UnLock(channel)
|
defer hub.subsLocker.UnLock(channel)
|
||||||
|
|
||||||
raw, ok := hub.subs.Get(channel)
|
raw, ok := hub.subs.Get(channel)
|
||||||
if !ok {
|
if !ok {
|
||||||
return reply.MakeIntReply(0)
|
return reply.MakeIntReply(0)
|
||||||
}
|
}
|
||||||
subscribers, _ := raw.(*list.LinkedList)
|
subscribers, _ := raw.(*list.LinkedList)
|
||||||
subscribers.ForEach(func(i int, c interface{}) bool {
|
subscribers.ForEach(func(i int, c interface{}) bool {
|
||||||
client, _ := c.(redis.Connection)
|
client, _ := c.(redis.Connection)
|
||||||
replyArgs := make([][]byte, 3)
|
replyArgs := make([][]byte, 3)
|
||||||
replyArgs[0] = messageBytes
|
replyArgs[0] = messageBytes
|
||||||
replyArgs[1] = []byte(channel)
|
replyArgs[1] = []byte(channel)
|
||||||
replyArgs[2] = message
|
replyArgs[2] = message
|
||||||
_ = client.Write(reply.MakeMultiBulkReply(replyArgs).ToBytes())
|
_ = client.Write(reply.MakeMultiBulkReply(replyArgs).ToBytes())
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return reply.MakeIntReply(int64(subscribers.Len()))
|
return reply.MakeIntReply(int64(subscribers.Len()))
|
||||||
}
|
}
|
||||||
|
@@ -1,322 +1,321 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/HDT3213/godis/src/interface/redis"
|
"github.com/HDT3213/godis/src/interface/redis"
|
||||||
"github.com/HDT3213/godis/src/lib/logger"
|
"github.com/HDT3213/godis/src/lib/logger"
|
||||||
"github.com/HDT3213/godis/src/lib/sync/wait"
|
"github.com/HDT3213/godis/src/lib/sync/wait"
|
||||||
"github.com/HDT3213/godis/src/redis/reply"
|
"github.com/HDT3213/godis/src/redis/reply"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
sendingReqs chan *Request // waiting sending
|
sendingReqs chan *Request // waiting sending
|
||||||
waitingReqs chan *Request // waiting response
|
waitingReqs chan *Request // waiting response
|
||||||
ticker *time.Ticker
|
ticker *time.Ticker
|
||||||
addr string
|
addr string
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancelFunc context.CancelFunc
|
cancelFunc context.CancelFunc
|
||||||
writing *sync.WaitGroup
|
writing *sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
type Request struct {
|
type Request struct {
|
||||||
id uint64
|
id uint64
|
||||||
args [][]byte
|
args [][]byte
|
||||||
reply redis.Reply
|
reply redis.Reply
|
||||||
heartbeat bool
|
heartbeat bool
|
||||||
waiting *wait.Wait
|
waiting *wait.Wait
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
chanSize = 256
|
chanSize = 256
|
||||||
maxWait = 3 * time.Second
|
maxWait = 3 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
func MakeClient(addr string) (*Client, error) {
|
func MakeClient(addr string) (*Client, error) {
|
||||||
conn, err := net.Dial("tcp", addr)
|
conn, err := net.Dial("tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &Client{
|
return &Client{
|
||||||
addr: addr,
|
addr: addr,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
sendingReqs: make(chan *Request, chanSize),
|
sendingReqs: make(chan *Request, chanSize),
|
||||||
waitingReqs: make(chan *Request, chanSize),
|
waitingReqs: make(chan *Request, chanSize),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancelFunc: cancel,
|
cancelFunc: cancel,
|
||||||
writing: &sync.WaitGroup{},
|
writing: &sync.WaitGroup{},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *Client) Start() {
|
func (client *Client) Start() {
|
||||||
client.ticker = time.NewTicker(10 * time.Second)
|
client.ticker = time.NewTicker(10 * time.Second)
|
||||||
go client.handleWrite()
|
go client.handleWrite()
|
||||||
go func() {
|
go func() {
|
||||||
err := client.handleRead()
|
err := client.handleRead()
|
||||||
logger.Warn(err)
|
logger.Warn(err)
|
||||||
}()
|
}()
|
||||||
go client.heartbeat()
|
go client.heartbeat()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *Client) Close() {
|
func (client *Client) Close() {
|
||||||
// stop new request
|
// stop new request
|
||||||
close(client.sendingReqs)
|
close(client.sendingReqs)
|
||||||
|
|
||||||
// wait stop process
|
// wait stop process
|
||||||
client.writing.Wait()
|
client.writing.Wait()
|
||||||
|
|
||||||
// clean
|
// clean
|
||||||
client.cancelFunc()
|
client.cancelFunc()
|
||||||
_ = client.conn.Close()
|
_ = client.conn.Close()
|
||||||
close(client.waitingReqs)
|
close(client.waitingReqs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *Client) handleConnectionError(err error) error {
|
func (client *Client) handleConnectionError(err error) error {
|
||||||
err1 := client.conn.Close()
|
err1 := client.conn.Close()
|
||||||
if err1 != nil {
|
if err1 != nil {
|
||||||
if opErr, ok := err1.(*net.OpError); ok {
|
if opErr, ok := err1.(*net.OpError); ok {
|
||||||
if opErr.Err.Error() != "use of closed network connection" {
|
if opErr.Err.Error() != "use of closed network connection" {
|
||||||
return err1
|
return err1
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return err1
|
return err1
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
conn, err1 := net.Dial("tcp", client.addr)
|
conn, err1 := net.Dial("tcp", client.addr)
|
||||||
if err1 != nil {
|
if err1 != nil {
|
||||||
logger.Error(err1)
|
logger.Error(err1)
|
||||||
return err1
|
return err1
|
||||||
}
|
}
|
||||||
client.conn = conn
|
client.conn = conn
|
||||||
go func() {
|
go func() {
|
||||||
_ = client.handleRead()
|
_ = client.handleRead()
|
||||||
}()
|
}()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *Client) heartbeat() {
|
func (client *Client) heartbeat() {
|
||||||
loop:
|
loop:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-client.ticker.C:
|
case <-client.ticker.C:
|
||||||
client.sendingReqs <- &Request{
|
client.sendingReqs <- &Request{
|
||||||
args: [][]byte{[]byte("PING")},
|
args: [][]byte{[]byte("PING")},
|
||||||
heartbeat: true,
|
heartbeat: true,
|
||||||
}
|
}
|
||||||
case <-client.ctx.Done():
|
case <-client.ctx.Done():
|
||||||
break loop
|
break loop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *Client) handleWrite() {
|
func (client *Client) handleWrite() {
|
||||||
loop:
|
loop:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case req := <-client.sendingReqs:
|
case req := <-client.sendingReqs:
|
||||||
client.writing.Add(1)
|
client.writing.Add(1)
|
||||||
client.doRequest(req)
|
client.doRequest(req)
|
||||||
case <-client.ctx.Done():
|
case <-client.ctx.Done():
|
||||||
break loop
|
break loop
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: wait with timeout
|
// todo: wait with timeout
|
||||||
func (client *Client) Send(args [][]byte) redis.Reply {
|
func (client *Client) Send(args [][]byte) redis.Reply {
|
||||||
request := &Request{
|
request := &Request{
|
||||||
args: args,
|
args: args,
|
||||||
heartbeat: false,
|
heartbeat: false,
|
||||||
waiting: &wait.Wait{},
|
waiting: &wait.Wait{},
|
||||||
}
|
}
|
||||||
request.waiting.Add(1)
|
request.waiting.Add(1)
|
||||||
client.sendingReqs <- request
|
client.sendingReqs <- request
|
||||||
timeout := request.waiting.WaitWithTimeout(maxWait)
|
timeout := request.waiting.WaitWithTimeout(maxWait)
|
||||||
if timeout {
|
if timeout {
|
||||||
return reply.MakeErrReply("server time out")
|
return reply.MakeErrReply("server time out")
|
||||||
}
|
}
|
||||||
if request.err != nil {
|
if request.err != nil {
|
||||||
return reply.MakeErrReply("request failed")
|
return reply.MakeErrReply("request failed")
|
||||||
}
|
}
|
||||||
return request.reply
|
return request.reply
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *Client) doRequest(req *Request) {
|
func (client *Client) doRequest(req *Request) {
|
||||||
bytes := reply.MakeMultiBulkReply(req.args).ToBytes()
|
bytes := reply.MakeMultiBulkReply(req.args).ToBytes()
|
||||||
_, err := client.conn.Write(bytes)
|
_, err := client.conn.Write(bytes)
|
||||||
i := 0
|
i := 0
|
||||||
for err != nil && i < 3 {
|
for err != nil && i < 3 {
|
||||||
err = client.handleConnectionError(err)
|
err = client.handleConnectionError(err)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
_, err = client.conn.Write(bytes)
|
_, err = client.conn.Write(bytes)
|
||||||
}
|
}
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
if err == nil {
|
if err == nil {
|
||||||
client.waitingReqs <- req
|
client.waitingReqs <- req
|
||||||
} else {
|
} else {
|
||||||
req.err = err
|
req.err = err
|
||||||
req.waiting.Done()
|
req.waiting.Done()
|
||||||
client.writing.Done()
|
client.writing.Done()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *Client) finishRequest(reply redis.Reply) {
|
func (client *Client) finishRequest(reply redis.Reply) {
|
||||||
request := <-client.waitingReqs
|
request := <-client.waitingReqs
|
||||||
request.reply = reply
|
request.reply = reply
|
||||||
if request.waiting != nil {
|
if request.waiting != nil {
|
||||||
request.waiting.Done()
|
request.waiting.Done()
|
||||||
}
|
}
|
||||||
client.writing.Done()
|
client.writing.Done()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (client *Client) handleRead() error {
|
func (client *Client) handleRead() error {
|
||||||
reader := bufio.NewReader(client.conn)
|
reader := bufio.NewReader(client.conn)
|
||||||
downloading := false
|
downloading := false
|
||||||
expectedArgsCount := 0
|
expectedArgsCount := 0
|
||||||
receivedCount := 0
|
receivedCount := 0
|
||||||
msgType := byte(0) // first char of msg
|
msgType := byte(0) // first char of msg
|
||||||
var args [][]byte
|
var args [][]byte
|
||||||
var fixedLen int64 = 0
|
var fixedLen int64 = 0
|
||||||
var err error
|
var err error
|
||||||
var msg []byte
|
var msg []byte
|
||||||
for {
|
for {
|
||||||
// read line
|
// read line
|
||||||
if fixedLen == 0 { // read normal line
|
if fixedLen == 0 { // read normal line
|
||||||
msg, err = reader.ReadBytes('\n')
|
msg, err = reader.ReadBytes('\n')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||||
logger.Info("connection close")
|
logger.Info("connection close")
|
||||||
} else {
|
} else {
|
||||||
logger.Warn(err)
|
logger.Warn(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return errors.New("connection closed")
|
return errors.New("connection closed")
|
||||||
}
|
}
|
||||||
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
|
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
|
||||||
return errors.New("protocol error")
|
return errors.New("protocol error")
|
||||||
}
|
}
|
||||||
} else { // read bulk line (binary safe)
|
} else { // read bulk line (binary safe)
|
||||||
msg = make([]byte, fixedLen+2)
|
msg = make([]byte, fixedLen+2)
|
||||||
_, err = io.ReadFull(reader, msg)
|
_, err = io.ReadFull(reader, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||||
return errors.New("connection closed")
|
return errors.New("connection closed")
|
||||||
} else {
|
} else {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(msg) == 0 ||
|
if len(msg) == 0 ||
|
||||||
msg[len(msg)-2] != '\r' ||
|
msg[len(msg)-2] != '\r' ||
|
||||||
msg[len(msg)-1] != '\n' {
|
msg[len(msg)-1] != '\n' {
|
||||||
return errors.New("protocol error")
|
return errors.New("protocol error")
|
||||||
}
|
}
|
||||||
fixedLen = 0
|
fixedLen = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// parse line
|
// parse line
|
||||||
if !downloading {
|
if !downloading {
|
||||||
// receive new response
|
// receive new response
|
||||||
if msg[0] == '*' { // multi bulk response
|
if msg[0] == '*' { // multi bulk response
|
||||||
// bulk multi msg
|
// bulk multi msg
|
||||||
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
|
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("protocol error: " + err.Error())
|
return errors.New("protocol error: " + err.Error())
|
||||||
}
|
}
|
||||||
if expectedLine == 0 {
|
if expectedLine == 0 {
|
||||||
client.finishRequest(&reply.EmptyMultiBulkReply{})
|
client.finishRequest(&reply.EmptyMultiBulkReply{})
|
||||||
} else if expectedLine > 0 {
|
} else if expectedLine > 0 {
|
||||||
msgType = msg[0]
|
msgType = msg[0]
|
||||||
downloading = true
|
downloading = true
|
||||||
expectedArgsCount = int(expectedLine)
|
expectedArgsCount = int(expectedLine)
|
||||||
receivedCount = 0
|
receivedCount = 0
|
||||||
args = make([][]byte, expectedLine)
|
args = make([][]byte, expectedLine)
|
||||||
} else {
|
} else {
|
||||||
return errors.New("protocol error")
|
return errors.New("protocol error")
|
||||||
}
|
}
|
||||||
} else if msg[0] == '$' { // bulk response
|
} else if msg[0] == '$' { // bulk response
|
||||||
fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
|
fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if fixedLen == -1 { // null bulk
|
if fixedLen == -1 { // null bulk
|
||||||
client.finishRequest(&reply.NullBulkReply{})
|
client.finishRequest(&reply.NullBulkReply{})
|
||||||
fixedLen = 0
|
fixedLen = 0
|
||||||
} else if fixedLen > 0 {
|
} else if fixedLen > 0 {
|
||||||
msgType = msg[0]
|
msgType = msg[0]
|
||||||
downloading = true
|
downloading = true
|
||||||
expectedArgsCount = 1
|
expectedArgsCount = 1
|
||||||
receivedCount = 0
|
receivedCount = 0
|
||||||
args = make([][]byte, 1)
|
args = make([][]byte, 1)
|
||||||
} else {
|
} else {
|
||||||
return errors.New("protocol error")
|
return errors.New("protocol error")
|
||||||
}
|
}
|
||||||
} else { // single line response
|
} else { // single line response
|
||||||
str := strings.TrimSuffix(string(msg), "\n")
|
str := strings.TrimSuffix(string(msg), "\n")
|
||||||
str = strings.TrimSuffix(str, "\r")
|
str = strings.TrimSuffix(str, "\r")
|
||||||
var result redis.Reply
|
var result redis.Reply
|
||||||
switch msg[0] {
|
switch msg[0] {
|
||||||
case '+':
|
case '+':
|
||||||
result = reply.MakeStatusReply(str[1:])
|
result = reply.MakeStatusReply(str[1:])
|
||||||
case '-':
|
case '-':
|
||||||
result = reply.MakeErrReply(str[1:])
|
result = reply.MakeErrReply(str[1:])
|
||||||
case ':':
|
case ':':
|
||||||
val, err := strconv.ParseInt(str[1:], 10, 64)
|
val, err := strconv.ParseInt(str[1:], 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("protocol error")
|
return errors.New("protocol error")
|
||||||
}
|
}
|
||||||
result = reply.MakeIntReply(val)
|
result = reply.MakeIntReply(val)
|
||||||
}
|
}
|
||||||
client.finishRequest(result)
|
client.finishRequest(result)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// receive following part of a request
|
// receive following part of a request
|
||||||
line := msg[0 : len(msg)-2]
|
line := msg[0 : len(msg)-2]
|
||||||
if line[0] == '$' {
|
if line[0] == '$' {
|
||||||
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
|
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if fixedLen <= 0 { // null bulk in multi bulks
|
if fixedLen <= 0 { // null bulk in multi bulks
|
||||||
args[receivedCount] = []byte{}
|
args[receivedCount] = []byte{}
|
||||||
receivedCount++
|
receivedCount++
|
||||||
fixedLen = 0
|
fixedLen = 0
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
args[receivedCount] = line
|
args[receivedCount] = line
|
||||||
receivedCount++
|
receivedCount++
|
||||||
}
|
}
|
||||||
|
|
||||||
// if sending finished
|
// if sending finished
|
||||||
if receivedCount == expectedArgsCount {
|
if receivedCount == expectedArgsCount {
|
||||||
downloading = false // finish downloading progress
|
downloading = false // finish downloading progress
|
||||||
|
|
||||||
if msgType == '*' {
|
if msgType == '*' {
|
||||||
reply := reply.MakeMultiBulkReply(args)
|
reply := reply.MakeMultiBulkReply(args)
|
||||||
client.finishRequest(reply)
|
client.finishRequest(reply)
|
||||||
} else if msgType == '$' {
|
} else if msgType == '$' {
|
||||||
reply := reply.MakeBulkReply(args[0])
|
reply := reply.MakeBulkReply(args[0])
|
||||||
client.finishRequest(reply)
|
client.finishRequest(reply)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// finish reply
|
||||||
// finish reply
|
expectedArgsCount = 0
|
||||||
expectedArgsCount = 0
|
receivedCount = 0
|
||||||
receivedCount = 0
|
args = nil
|
||||||
args = nil
|
msgType = byte(0)
|
||||||
msgType = byte(0)
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@@ -1,104 +1,104 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/HDT3213/godis/src/lib/logger"
|
"github.com/HDT3213/godis/src/lib/logger"
|
||||||
"github.com/HDT3213/godis/src/redis/reply"
|
"github.com/HDT3213/godis/src/redis/reply"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestClient(t *testing.T) {
|
func TestClient(t *testing.T) {
|
||||||
logger.Setup(&logger.Settings{
|
logger.Setup(&logger.Settings{
|
||||||
Path: "logs",
|
Path: "logs",
|
||||||
Name: "godis",
|
Name: "godis",
|
||||||
Ext: ".log",
|
Ext: ".log",
|
||||||
TimeFormat: "2006-01-02",
|
TimeFormat: "2006-01-02",
|
||||||
})
|
})
|
||||||
client, err := MakeClient("localhost:6379")
|
client, err := MakeClient("localhost:6379")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
client.Start()
|
client.Start()
|
||||||
|
|
||||||
result := client.Send([][]byte{
|
result := client.Send([][]byte{
|
||||||
[]byte("PING"),
|
[]byte("PING"),
|
||||||
})
|
})
|
||||||
if statusRet, ok := result.(*reply.StatusReply); ok {
|
if statusRet, ok := result.(*reply.StatusReply); ok {
|
||||||
if statusRet.Status != "PONG" {
|
if statusRet.Status != "PONG" {
|
||||||
t.Error("`ping` failed, result: " + statusRet.Status)
|
t.Error("`ping` failed, result: " + statusRet.Status)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result = client.Send([][]byte{
|
result = client.Send([][]byte{
|
||||||
[]byte("SET"),
|
[]byte("SET"),
|
||||||
[]byte("a"),
|
[]byte("a"),
|
||||||
[]byte("a"),
|
[]byte("a"),
|
||||||
})
|
})
|
||||||
if statusRet, ok := result.(*reply.StatusReply); ok {
|
if statusRet, ok := result.(*reply.StatusReply); ok {
|
||||||
if statusRet.Status != "OK" {
|
if statusRet.Status != "OK" {
|
||||||
t.Error("`set` failed, result: " + statusRet.Status)
|
t.Error("`set` failed, result: " + statusRet.Status)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result = client.Send([][]byte{
|
result = client.Send([][]byte{
|
||||||
[]byte("GET"),
|
[]byte("GET"),
|
||||||
[]byte("a"),
|
[]byte("a"),
|
||||||
})
|
})
|
||||||
if bulkRet, ok := result.(*reply.BulkReply); ok {
|
if bulkRet, ok := result.(*reply.BulkReply); ok {
|
||||||
if string(bulkRet.Arg) != "a" {
|
if string(bulkRet.Arg) != "a" {
|
||||||
t.Error("`get` failed, result: " + string(bulkRet.Arg))
|
t.Error("`get` failed, result: " + string(bulkRet.Arg))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result = client.Send([][]byte{
|
result = client.Send([][]byte{
|
||||||
[]byte("DEL"),
|
[]byte("DEL"),
|
||||||
[]byte("a"),
|
[]byte("a"),
|
||||||
})
|
})
|
||||||
if intRet, ok := result.(*reply.IntReply); ok {
|
if intRet, ok := result.(*reply.IntReply); ok {
|
||||||
if intRet.Code != 1 {
|
if intRet.Code != 1 {
|
||||||
t.Error("`del` failed, result: " + string(intRet.Code))
|
t.Error("`del` failed, result: " + string(intRet.Code))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result = client.Send([][]byte{
|
result = client.Send([][]byte{
|
||||||
[]byte("GET"),
|
[]byte("GET"),
|
||||||
[]byte("a"),
|
[]byte("a"),
|
||||||
})
|
})
|
||||||
if _, ok := result.(*reply.NullBulkReply); !ok {
|
if _, ok := result.(*reply.NullBulkReply); !ok {
|
||||||
t.Error("`get` failed, result: " + string(result.ToBytes()))
|
t.Error("`get` failed, result: " + string(result.ToBytes()))
|
||||||
}
|
}
|
||||||
|
|
||||||
result = client.Send([][]byte{
|
result = client.Send([][]byte{
|
||||||
[]byte("DEL"),
|
[]byte("DEL"),
|
||||||
[]byte("arr"),
|
[]byte("arr"),
|
||||||
})
|
})
|
||||||
|
|
||||||
result = client.Send([][]byte{
|
result = client.Send([][]byte{
|
||||||
[]byte("RPUSH"),
|
[]byte("RPUSH"),
|
||||||
[]byte("arr"),
|
[]byte("arr"),
|
||||||
[]byte("1"),
|
[]byte("1"),
|
||||||
[]byte("2"),
|
[]byte("2"),
|
||||||
[]byte("c"),
|
[]byte("c"),
|
||||||
})
|
})
|
||||||
if intRet, ok := result.(*reply.IntReply); ok {
|
if intRet, ok := result.(*reply.IntReply); ok {
|
||||||
if intRet.Code != 3 {
|
if intRet.Code != 3 {
|
||||||
t.Error("`rpush` failed, result: " + string(intRet.Code))
|
t.Error("`rpush` failed, result: " + string(intRet.Code))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result = client.Send([][]byte{
|
result = client.Send([][]byte{
|
||||||
[]byte("LRANGE"),
|
[]byte("LRANGE"),
|
||||||
[]byte("arr"),
|
[]byte("arr"),
|
||||||
[]byte("0"),
|
[]byte("0"),
|
||||||
[]byte("-1"),
|
[]byte("-1"),
|
||||||
})
|
})
|
||||||
if multiBulkRet, ok := result.(*reply.MultiBulkReply); ok {
|
if multiBulkRet, ok := result.(*reply.MultiBulkReply); ok {
|
||||||
if len(multiBulkRet.Args) != 3 ||
|
if len(multiBulkRet.Args) != 3 ||
|
||||||
string(multiBulkRet.Args[0]) != "1" ||
|
string(multiBulkRet.Args[0]) != "1" ||
|
||||||
string(multiBulkRet.Args[1]) != "2" ||
|
string(multiBulkRet.Args[1]) != "2" ||
|
||||||
string(multiBulkRet.Args[2]) != "c" {
|
string(multiBulkRet.Args[2]) != "c" {
|
||||||
t.Error("`lrange` failed, result: " + string(multiBulkRet.ToBytes()))
|
t.Error("`lrange` failed, result: " + string(multiBulkRet.ToBytes()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
client.Close()
|
client.Close()
|
||||||
}
|
}
|
||||||
|
@@ -1,42 +1,42 @@
|
|||||||
package reply
|
package reply
|
||||||
|
|
||||||
type PongReply struct {}
|
type PongReply struct{}
|
||||||
|
|
||||||
var PongBytes = []byte("+PONG\r\n")
|
var PongBytes = []byte("+PONG\r\n")
|
||||||
|
|
||||||
func (r *PongReply)ToBytes()[]byte {
|
func (r *PongReply) ToBytes() []byte {
|
||||||
return PongBytes
|
return PongBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
type OkReply struct {}
|
type OkReply struct{}
|
||||||
|
|
||||||
var okBytes = []byte("+OK\r\n")
|
var okBytes = []byte("+OK\r\n")
|
||||||
|
|
||||||
func (r *OkReply)ToBytes()[]byte {
|
func (r *OkReply) ToBytes() []byte {
|
||||||
return okBytes
|
return okBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
var nullBulkBytes = []byte("$-1\r\n")
|
var nullBulkBytes = []byte("$-1\r\n")
|
||||||
|
|
||||||
type NullBulkReply struct {}
|
type NullBulkReply struct{}
|
||||||
|
|
||||||
func (r *NullBulkReply)ToBytes()[]byte {
|
func (r *NullBulkReply) ToBytes() []byte {
|
||||||
return nullBulkBytes
|
return nullBulkBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
var emptyMultiBulkBytes = []byte("*0\r\n")
|
var emptyMultiBulkBytes = []byte("*0\r\n")
|
||||||
|
|
||||||
type EmptyMultiBulkReply struct {}
|
type EmptyMultiBulkReply struct{}
|
||||||
|
|
||||||
func (r *EmptyMultiBulkReply)ToBytes()[]byte {
|
func (r *EmptyMultiBulkReply) ToBytes() []byte {
|
||||||
return emptyMultiBulkBytes
|
return emptyMultiBulkBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
// reply nothing, for commands like subscribe
|
// reply nothing, for commands like subscribe
|
||||||
type NoReply struct {}
|
type NoReply struct{}
|
||||||
|
|
||||||
var NoBytes = []byte("")
|
var NoBytes = []byte("")
|
||||||
|
|
||||||
func (r *NoReply)ToBytes()[]byte {
|
func (r *NoReply) ToBytes() []byte {
|
||||||
return NoBytes
|
return NoBytes
|
||||||
}
|
}
|
||||||
|
@@ -1,67 +1,67 @@
|
|||||||
package reply
|
package reply
|
||||||
|
|
||||||
// UnknownErr
|
// UnknownErr
|
||||||
type UnknownErrReply struct {}
|
type UnknownErrReply struct{}
|
||||||
|
|
||||||
var unknownErrBytes = []byte("-Err unknown\r\n")
|
var unknownErrBytes = []byte("-Err unknown\r\n")
|
||||||
|
|
||||||
func (r *UnknownErrReply)ToBytes()[]byte {
|
func (r *UnknownErrReply) ToBytes() []byte {
|
||||||
return unknownErrBytes
|
return unknownErrBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *UnknownErrReply) Error()string {
|
func (r *UnknownErrReply) Error() string {
|
||||||
return "Err unknown"
|
return "Err unknown"
|
||||||
}
|
}
|
||||||
|
|
||||||
// ArgNumErr
|
// ArgNumErr
|
||||||
type ArgNumErrReply struct {
|
type ArgNumErrReply struct {
|
||||||
Cmd string
|
Cmd string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ArgNumErrReply)ToBytes()[]byte {
|
func (r *ArgNumErrReply) ToBytes() []byte {
|
||||||
return []byte("-ERR wrong number of arguments for '" + r.Cmd + "' command\r\n")
|
return []byte("-ERR wrong number of arguments for '" + r.Cmd + "' command\r\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ArgNumErrReply) Error()string {
|
func (r *ArgNumErrReply) Error() string {
|
||||||
return "ERR wrong number of arguments for '" + r.Cmd + "' command"
|
return "ERR wrong number of arguments for '" + r.Cmd + "' command"
|
||||||
}
|
}
|
||||||
|
|
||||||
// SyntaxErr
|
// SyntaxErr
|
||||||
type SyntaxErrReply struct {}
|
type SyntaxErrReply struct{}
|
||||||
|
|
||||||
var syntaxErrBytes = []byte("-Err syntax error\r\n")
|
var syntaxErrBytes = []byte("-Err syntax error\r\n")
|
||||||
|
|
||||||
func (r *SyntaxErrReply)ToBytes()[]byte {
|
func (r *SyntaxErrReply) ToBytes() []byte {
|
||||||
return syntaxErrBytes
|
return syntaxErrBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SyntaxErrReply)Error()string {
|
func (r *SyntaxErrReply) Error() string {
|
||||||
return "Err syntax error"
|
return "Err syntax error"
|
||||||
}
|
}
|
||||||
|
|
||||||
// WrongTypeErr
|
// WrongTypeErr
|
||||||
type WrongTypeErrReply struct {}
|
type WrongTypeErrReply struct{}
|
||||||
|
|
||||||
var wrongTypeErrBytes = []byte("-WRONGTYPE Operation against a key holding the wrong kind of value\r\n")
|
var wrongTypeErrBytes = []byte("-WRONGTYPE Operation against a key holding the wrong kind of value\r\n")
|
||||||
|
|
||||||
func (r *WrongTypeErrReply)ToBytes()[]byte {
|
func (r *WrongTypeErrReply) ToBytes() []byte {
|
||||||
return wrongTypeErrBytes
|
return wrongTypeErrBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *WrongTypeErrReply)Error()string {
|
func (r *WrongTypeErrReply) Error() string {
|
||||||
return "WRONGTYPE Operation against a key holding the wrong kind of value"
|
return "WRONGTYPE Operation against a key holding the wrong kind of value"
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProtocolErr
|
// ProtocolErr
|
||||||
|
|
||||||
type ProtocolErrReply struct {
|
type ProtocolErrReply struct {
|
||||||
Msg string
|
Msg string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ProtocolErrReply)ToBytes()[]byte {
|
func (r *ProtocolErrReply) ToBytes() []byte {
|
||||||
return []byte("-ERR Protocol error: '" + r.Msg + "'\r\n")
|
return []byte("-ERR Protocol error: '" + r.Msg + "'\r\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ProtocolErrReply) Error()string {
|
func (r *ProtocolErrReply) Error() string {
|
||||||
return "ERR Protocol error: '" + r.Msg
|
return "ERR Protocol error: '" + r.Msg
|
||||||
}
|
}
|
||||||
|
@@ -1,141 +1,140 @@
|
|||||||
package reply
|
package reply
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"github.com/HDT3213/godis/src/interface/redis"
|
"github.com/HDT3213/godis/src/interface/redis"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
nullBulkReplyBytes = []byte("$-1")
|
nullBulkReplyBytes = []byte("$-1")
|
||||||
CRLF = "\r\n"
|
CRLF = "\r\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* ---- Bulk Reply ---- */
|
/* ---- Bulk Reply ---- */
|
||||||
|
|
||||||
type BulkReply struct {
|
type BulkReply struct {
|
||||||
Arg []byte
|
Arg []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeBulkReply(arg []byte) *BulkReply {
|
func MakeBulkReply(arg []byte) *BulkReply {
|
||||||
return &BulkReply{
|
return &BulkReply{
|
||||||
Arg: arg,
|
Arg: arg,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *BulkReply) ToBytes() []byte {
|
func (r *BulkReply) ToBytes() []byte {
|
||||||
if len(r.Arg) == 0 {
|
if len(r.Arg) == 0 {
|
||||||
return nullBulkReplyBytes
|
return nullBulkReplyBytes
|
||||||
}
|
}
|
||||||
return []byte("$" + strconv.Itoa(len(r.Arg)) + CRLF + string(r.Arg) + CRLF)
|
return []byte("$" + strconv.Itoa(len(r.Arg)) + CRLF + string(r.Arg) + CRLF)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ---- Multi Bulk Reply ---- */
|
/* ---- Multi Bulk Reply ---- */
|
||||||
|
|
||||||
type MultiBulkReply struct {
|
type MultiBulkReply struct {
|
||||||
Args [][]byte
|
Args [][]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeMultiBulkReply(args [][]byte) *MultiBulkReply {
|
func MakeMultiBulkReply(args [][]byte) *MultiBulkReply {
|
||||||
return &MultiBulkReply{
|
return &MultiBulkReply{
|
||||||
Args: args,
|
Args: args,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *MultiBulkReply) ToBytes() []byte {
|
func (r *MultiBulkReply) ToBytes() []byte {
|
||||||
argLen := len(r.Args)
|
argLen := len(r.Args)
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
buf.WriteString("*" + strconv.Itoa(argLen) + CRLF)
|
buf.WriteString("*" + strconv.Itoa(argLen) + CRLF)
|
||||||
for _, arg := range r.Args {
|
for _, arg := range r.Args {
|
||||||
if arg == nil {
|
if arg == nil {
|
||||||
buf.WriteString("$-1" + CRLF)
|
buf.WriteString("$-1" + CRLF)
|
||||||
} else {
|
} else {
|
||||||
buf.WriteString("$" + strconv.Itoa(len(arg)) + CRLF + string(arg) + CRLF)
|
buf.WriteString("$" + strconv.Itoa(len(arg)) + CRLF + string(arg) + CRLF)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return buf.Bytes()
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ---- Multi Raw Reply ---- */
|
/* ---- Multi Raw Reply ---- */
|
||||||
|
|
||||||
type MultiRawReply struct {
|
type MultiRawReply struct {
|
||||||
Args [][]byte
|
Args [][]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeMultiRawReply(args [][]byte) *MultiRawReply {
|
func MakeMultiRawReply(args [][]byte) *MultiRawReply {
|
||||||
return &MultiRawReply{
|
return &MultiRawReply{
|
||||||
Args: args,
|
Args: args,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *MultiRawReply) ToBytes() []byte {
|
func (r *MultiRawReply) ToBytes() []byte {
|
||||||
argLen := len(r.Args)
|
argLen := len(r.Args)
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
buf.WriteString("*" + strconv.Itoa(argLen) + CRLF)
|
buf.WriteString("*" + strconv.Itoa(argLen) + CRLF)
|
||||||
for _, arg := range r.Args {
|
for _, arg := range r.Args {
|
||||||
buf.Write(arg)
|
buf.Write(arg)
|
||||||
}
|
}
|
||||||
return buf.Bytes()
|
return buf.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ---- Status Reply ---- */
|
/* ---- Status Reply ---- */
|
||||||
|
|
||||||
type StatusReply struct {
|
type StatusReply struct {
|
||||||
Status string
|
Status string
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeStatusReply(status string) *StatusReply {
|
func MakeStatusReply(status string) *StatusReply {
|
||||||
return &StatusReply{
|
return &StatusReply{
|
||||||
Status: status,
|
Status: status,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *StatusReply) ToBytes() []byte {
|
func (r *StatusReply) ToBytes() []byte {
|
||||||
return []byte("+" + r.Status + "\r\n")
|
return []byte("+" + r.Status + "\r\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ---- Int Reply ---- */
|
/* ---- Int Reply ---- */
|
||||||
|
|
||||||
type IntReply struct {
|
type IntReply struct {
|
||||||
Code int64
|
Code int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeIntReply(code int64) *IntReply {
|
func MakeIntReply(code int64) *IntReply {
|
||||||
return &IntReply{
|
return &IntReply{
|
||||||
Code: code,
|
Code: code,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *IntReply) ToBytes() []byte {
|
func (r *IntReply) ToBytes() []byte {
|
||||||
return []byte(":" + strconv.FormatInt(r.Code, 10) + CRLF)
|
return []byte(":" + strconv.FormatInt(r.Code, 10) + CRLF)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/* ---- Error Reply ---- */
|
/* ---- Error Reply ---- */
|
||||||
|
|
||||||
type ErrorReply interface {
|
type ErrorReply interface {
|
||||||
Error() string
|
Error() string
|
||||||
ToBytes() []byte
|
ToBytes() []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type StandardErrReply struct {
|
type StandardErrReply struct {
|
||||||
Status string
|
Status string
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeErrReply(status string) *StandardErrReply {
|
func MakeErrReply(status string) *StandardErrReply {
|
||||||
return &StandardErrReply{
|
return &StandardErrReply{
|
||||||
Status: status,
|
Status: status,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsErrorReply(reply redis.Reply) bool {
|
func IsErrorReply(reply redis.Reply) bool {
|
||||||
return reply.ToBytes()[0] == '-'
|
return reply.ToBytes()[0] == '-'
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *StandardErrReply) ToBytes() []byte {
|
func (r *StandardErrReply) ToBytes() []byte {
|
||||||
return []byte("-" + r.Status + "\r\n")
|
return []byte("-" + r.Status + "\r\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *StandardErrReply) Error() string {
|
func (r *StandardErrReply) Error() string {
|
||||||
return r.Status
|
return r.Status
|
||||||
}
|
}
|
||||||
|
@@ -1,95 +1,95 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/HDT3213/godis/src/lib/sync/atomic"
|
"github.com/HDT3213/godis/src/lib/sync/atomic"
|
||||||
"github.com/HDT3213/godis/src/lib/sync/wait"
|
"github.com/HDT3213/godis/src/lib/sync/wait"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// abstract of active client
|
// abstract of active client
|
||||||
type Client struct {
|
type Client struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
|
|
||||||
// waiting util reply finished
|
// waiting util reply finished
|
||||||
waitingReply wait.Wait
|
waitingReply wait.Wait
|
||||||
|
|
||||||
// is sending request in progress
|
// is sending request in progress
|
||||||
uploading atomic.AtomicBool
|
uploading atomic.AtomicBool
|
||||||
// multi bulk msg lineCount - 1(first line)
|
// multi bulk msg lineCount - 1(first line)
|
||||||
expectedArgsCount uint32
|
expectedArgsCount uint32
|
||||||
// sent line count, exclude first line
|
// sent line count, exclude first line
|
||||||
receivedCount uint32
|
receivedCount uint32
|
||||||
// sent lines, exclude first line
|
// sent lines, exclude first line
|
||||||
args [][]byte
|
args [][]byte
|
||||||
|
|
||||||
// lock while server sending response
|
// lock while server sending response
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
|
||||||
// subscribing channels
|
// subscribing channels
|
||||||
subs map[string]bool
|
subs map[string]bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client)Close()error {
|
func (c *Client) Close() error {
|
||||||
c.waitingReply.WaitWithTimeout(10 * time.Second)
|
c.waitingReply.WaitWithTimeout(10 * time.Second)
|
||||||
_ = c.conn.Close()
|
_ = c.conn.Close()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeClient(conn net.Conn) *Client {
|
func MakeClient(conn net.Conn) *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client)Write(b []byte)error {
|
func (c *Client) Write(b []byte) error {
|
||||||
if b == nil || len(b) == 0 {
|
if b == nil || len(b) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
_, err := c.conn.Write(b)
|
_, err := c.conn.Write(b)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client)SubsChannel(channel string) {
|
func (c *Client) SubsChannel(channel string) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
if c.subs == nil {
|
if c.subs == nil {
|
||||||
c.subs = make(map[string]bool)
|
c.subs = make(map[string]bool)
|
||||||
}
|
}
|
||||||
c.subs[channel] = true
|
c.subs[channel] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client)UnSubsChannel(channel string) {
|
func (c *Client) UnSubsChannel(channel string) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
if c.subs == nil {
|
if c.subs == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
delete(c.subs, channel)
|
delete(c.subs, channel)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client)SubsCount()int {
|
func (c *Client) SubsCount() int {
|
||||||
if c.subs == nil {
|
if c.subs == nil {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
return len(c.subs)
|
return len(c.subs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client)GetChannels()[]string {
|
func (c *Client) GetChannels() []string {
|
||||||
if c.subs == nil {
|
if c.subs == nil {
|
||||||
return make([]string, 0)
|
return make([]string, 0)
|
||||||
}
|
}
|
||||||
channels := make([]string, len(c.subs))
|
channels := make([]string, len(c.subs))
|
||||||
i := 0
|
i := 0
|
||||||
for channel := range c.subs {
|
for channel := range c.subs {
|
||||||
channels[i] = channel
|
channels[i] = channel
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
return channels
|
return channels
|
||||||
}
|
}
|
||||||
|
@@ -5,192 +5,192 @@ package server
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"github.com/HDT3213/godis/src/cluster"
|
"github.com/HDT3213/godis/src/cluster"
|
||||||
"github.com/HDT3213/godis/src/config"
|
"github.com/HDT3213/godis/src/config"
|
||||||
DBImpl "github.com/HDT3213/godis/src/db"
|
DBImpl "github.com/HDT3213/godis/src/db"
|
||||||
"github.com/HDT3213/godis/src/interface/db"
|
"github.com/HDT3213/godis/src/interface/db"
|
||||||
"github.com/HDT3213/godis/src/lib/logger"
|
"github.com/HDT3213/godis/src/lib/logger"
|
||||||
"github.com/HDT3213/godis/src/lib/sync/atomic"
|
"github.com/HDT3213/godis/src/lib/sync/atomic"
|
||||||
"github.com/HDT3213/godis/src/redis/reply"
|
"github.com/HDT3213/godis/src/redis/reply"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
UnknownErrReplyBytes = []byte("-ERR unknown\r\n")
|
UnknownErrReplyBytes = []byte("-ERR unknown\r\n")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
activeConn sync.Map // *client -> placeholder
|
activeConn sync.Map // *client -> placeholder
|
||||||
db db.DB
|
db db.DB
|
||||||
closing atomic.AtomicBool // refusing new client and new request
|
closing atomic.AtomicBool // refusing new client and new request
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeHandler() *Handler {
|
func MakeHandler() *Handler {
|
||||||
var db db.DB
|
var db db.DB
|
||||||
if config.Properties.Peers != nil &&
|
if config.Properties.Peers != nil &&
|
||||||
len(config.Properties.Peers) > 0 {
|
len(config.Properties.Peers) > 0 {
|
||||||
db = cluster.MakeCluster()
|
db = cluster.MakeCluster()
|
||||||
} else {
|
} else {
|
||||||
db = DBImpl.MakeDB()
|
db = DBImpl.MakeDB()
|
||||||
}
|
}
|
||||||
return &Handler{
|
return &Handler{
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) closeClient(client *Client) {
|
func (h *Handler) closeClient(client *Client) {
|
||||||
_ = client.Close()
|
_ = client.Close()
|
||||||
h.db.AfterClientClose(client)
|
h.db.AfterClientClose(client)
|
||||||
h.activeConn.Delete(client)
|
h.activeConn.Delete(client)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
|
func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||||
if h.closing.Get() {
|
if h.closing.Get() {
|
||||||
// closing handler refuse new connection
|
// closing handler refuse new connection
|
||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
client := MakeClient(conn)
|
client := MakeClient(conn)
|
||||||
h.activeConn.Store(client, 1)
|
h.activeConn.Store(client, 1)
|
||||||
|
|
||||||
reader := bufio.NewReader(conn)
|
reader := bufio.NewReader(conn)
|
||||||
var fixedLen int64 = 0
|
var fixedLen int64 = 0
|
||||||
var err error
|
var err error
|
||||||
var msg []byte
|
var msg []byte
|
||||||
for {
|
for {
|
||||||
if fixedLen == 0 {
|
if fixedLen == 0 {
|
||||||
msg, err = reader.ReadBytes('\n')
|
msg, err = reader.ReadBytes('\n')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == io.EOF ||
|
if err == io.EOF ||
|
||||||
err == io.ErrUnexpectedEOF ||
|
err == io.ErrUnexpectedEOF ||
|
||||||
strings.Contains(err.Error(), "use of closed network connection") {
|
strings.Contains(err.Error(), "use of closed network connection") {
|
||||||
logger.Info("connection close")
|
logger.Info("connection close")
|
||||||
} else {
|
} else {
|
||||||
logger.Warn(err)
|
logger.Warn(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// after client close
|
// after client close
|
||||||
h.closeClient(client)
|
h.closeClient(client)
|
||||||
return // io error, disconnect with client
|
return // io error, disconnect with client
|
||||||
}
|
}
|
||||||
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
|
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
|
||||||
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
|
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
|
||||||
_, _ = client.conn.Write(errReply.ToBytes())
|
_, _ = client.conn.Write(errReply.ToBytes())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
msg = make([]byte, fixedLen+2)
|
msg = make([]byte, fixedLen+2)
|
||||||
_, err = io.ReadFull(reader, msg)
|
_, err = io.ReadFull(reader, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == io.EOF ||
|
if err == io.EOF ||
|
||||||
err == io.ErrUnexpectedEOF ||
|
err == io.ErrUnexpectedEOF ||
|
||||||
strings.Contains(err.Error(), "use of closed network connection") {
|
strings.Contains(err.Error(), "use of closed network connection") {
|
||||||
logger.Info("connection close")
|
logger.Info("connection close")
|
||||||
} else {
|
} else {
|
||||||
logger.Warn(err)
|
logger.Warn(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// after client close
|
// after client close
|
||||||
h.closeClient(client)
|
h.closeClient(client)
|
||||||
return // io error, disconnect with client
|
return // io error, disconnect with client
|
||||||
}
|
}
|
||||||
if len(msg) == 0 ||
|
if len(msg) == 0 ||
|
||||||
msg[len(msg)-2] != '\r' ||
|
msg[len(msg)-2] != '\r' ||
|
||||||
msg[len(msg)-1] != '\n' {
|
msg[len(msg)-1] != '\n' {
|
||||||
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
|
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
|
||||||
_, _ = client.conn.Write(errReply.ToBytes())
|
_, _ = client.conn.Write(errReply.ToBytes())
|
||||||
}
|
}
|
||||||
fixedLen = 0
|
fixedLen = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if !client.uploading.Get() {
|
if !client.uploading.Get() {
|
||||||
// new request
|
// new request
|
||||||
if msg[0] == '*' {
|
if msg[0] == '*' {
|
||||||
// bulk multi msg
|
// bulk multi msg
|
||||||
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
|
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_, _ = client.conn.Write(UnknownErrReplyBytes)
|
_, _ = client.conn.Write(UnknownErrReplyBytes)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
client.waitingReply.Add(1)
|
client.waitingReply.Add(1)
|
||||||
client.uploading.Set(true)
|
client.uploading.Set(true)
|
||||||
client.expectedArgsCount = uint32(expectedLine)
|
client.expectedArgsCount = uint32(expectedLine)
|
||||||
client.receivedCount = 0
|
client.receivedCount = 0
|
||||||
client.args = make([][]byte, expectedLine)
|
client.args = make([][]byte, expectedLine)
|
||||||
} else {
|
} else {
|
||||||
// text protocol
|
// text protocol
|
||||||
// remove \r or \n or \r\n in the end of line
|
// remove \r or \n or \r\n in the end of line
|
||||||
str := strings.TrimSuffix(string(msg), "\n")
|
str := strings.TrimSuffix(string(msg), "\n")
|
||||||
str = strings.TrimSuffix(str, "\r")
|
str = strings.TrimSuffix(str, "\r")
|
||||||
strs := strings.Split(str, " ")
|
strs := strings.Split(str, " ")
|
||||||
args := make([][]byte, len(strs))
|
args := make([][]byte, len(strs))
|
||||||
for i, s := range strs {
|
for i, s := range strs {
|
||||||
args[i] = []byte(s)
|
args[i] = []byte(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// send reply
|
// send reply
|
||||||
result := h.db.Exec(client, args)
|
result := h.db.Exec(client, args)
|
||||||
if result != nil {
|
if result != nil {
|
||||||
_ = client.Write(result.ToBytes())
|
_ = client.Write(result.ToBytes())
|
||||||
} else {
|
} else {
|
||||||
_ = client.Write(UnknownErrReplyBytes)
|
_ = client.Write(UnknownErrReplyBytes)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// receive following part of a request
|
// receive following part of a request
|
||||||
line := msg[0 : len(msg)-2]
|
line := msg[0 : len(msg)-2]
|
||||||
if line[0] == '$' {
|
if line[0] == '$' {
|
||||||
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
|
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errReply := &reply.ProtocolErrReply{Msg: err.Error()}
|
errReply := &reply.ProtocolErrReply{Msg: err.Error()}
|
||||||
_, _ = client.conn.Write(errReply.ToBytes())
|
_, _ = client.conn.Write(errReply.ToBytes())
|
||||||
}
|
}
|
||||||
if fixedLen <= 0 {
|
if fixedLen <= 0 {
|
||||||
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
|
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
|
||||||
_, _ = client.conn.Write(errReply.ToBytes())
|
_, _ = client.conn.Write(errReply.ToBytes())
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
client.args[client.receivedCount] = line
|
client.args[client.receivedCount] = line
|
||||||
client.receivedCount++
|
client.receivedCount++
|
||||||
}
|
}
|
||||||
|
|
||||||
// if sending finished
|
// if sending finished
|
||||||
if client.receivedCount == client.expectedArgsCount {
|
if client.receivedCount == client.expectedArgsCount {
|
||||||
client.uploading.Set(false) // finish sending progress
|
client.uploading.Set(false) // finish sending progress
|
||||||
|
|
||||||
// send reply
|
// send reply
|
||||||
result := h.db.Exec(client, client.args)
|
result := h.db.Exec(client, client.args)
|
||||||
if result != nil {
|
if result != nil {
|
||||||
_ = client.Write(result.ToBytes())
|
_ = client.Write(result.ToBytes())
|
||||||
} else {
|
} else {
|
||||||
_ = client.Write(UnknownErrReplyBytes)
|
_ = client.Write(UnknownErrReplyBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// finish reply
|
// finish reply
|
||||||
client.expectedArgsCount = 0
|
client.expectedArgsCount = 0
|
||||||
client.receivedCount = 0
|
client.receivedCount = 0
|
||||||
client.args = nil
|
client.args = nil
|
||||||
client.waitingReply.Done()
|
client.waitingReply.Done()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) Close() error {
|
func (h *Handler) Close() error {
|
||||||
logger.Info("handler shuting down...")
|
logger.Info("handler shuting down...")
|
||||||
h.closing.Set(true)
|
h.closing.Set(true)
|
||||||
// TODO: concurrent wait
|
// TODO: concurrent wait
|
||||||
h.activeConn.Range(func(key interface{}, val interface{}) bool {
|
h.activeConn.Range(func(key interface{}, val interface{}) bool {
|
||||||
client := key.(*Client)
|
client := key.(*Client)
|
||||||
_ = client.Close()
|
_ = client.Close()
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
h.db.Close()
|
h.db.Close()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
119
src/tcp/echo.go
119
src/tcp/echo.go
@@ -5,79 +5,78 @@ package tcp
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"bufio"
|
"github.com/HDT3213/godis/src/lib/logger"
|
||||||
"github.com/HDT3213/godis/src/lib/logger"
|
"github.com/HDT3213/godis/src/lib/sync/atomic"
|
||||||
"sync"
|
"github.com/HDT3213/godis/src/lib/sync/wait"
|
||||||
"io"
|
"io"
|
||||||
"github.com/HDT3213/godis/src/lib/sync/atomic"
|
"net"
|
||||||
"time"
|
"sync"
|
||||||
"github.com/HDT3213/godis/src/lib/sync/wait"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type EchoHandler struct {
|
type EchoHandler struct {
|
||||||
activeConn sync.Map
|
activeConn sync.Map
|
||||||
closing atomic.AtomicBool
|
closing atomic.AtomicBool
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeEchoHandler()(*EchoHandler) {
|
func MakeEchoHandler() *EchoHandler {
|
||||||
return &EchoHandler{
|
return &EchoHandler{}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
Conn net.Conn
|
Conn net.Conn
|
||||||
Waiting wait.Wait
|
Waiting wait.Wait
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client)Close()error {
|
func (c *Client) Close() error {
|
||||||
c.Waiting.WaitWithTimeout(10 * time.Second)
|
c.Waiting.WaitWithTimeout(10 * time.Second)
|
||||||
c.Conn.Close()
|
c.Conn.Close()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *EchoHandler)Handle(ctx context.Context, conn net.Conn) {
|
func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||||
if h.closing.Get() {
|
if h.closing.Get() {
|
||||||
// closing handler refuse new connection
|
// closing handler refuse new connection
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &Client {
|
client := &Client{
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
}
|
}
|
||||||
h.activeConn.Store(client, 1)
|
h.activeConn.Store(client, 1)
|
||||||
|
|
||||||
reader := bufio.NewReader(conn)
|
reader := bufio.NewReader(conn)
|
||||||
for {
|
for {
|
||||||
// may occurs: client EOF, client timeout, server early close
|
// may occurs: client EOF, client timeout, server early close
|
||||||
msg, err := reader.ReadString('\n')
|
msg, err := reader.ReadString('\n')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
logger.Info("connection close")
|
logger.Info("connection close")
|
||||||
h.activeConn.Delete(client)
|
h.activeConn.Delete(client)
|
||||||
} else {
|
} else {
|
||||||
logger.Warn(err)
|
logger.Warn(err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
client.Waiting.Add(1)
|
client.Waiting.Add(1)
|
||||||
//logger.Info("sleeping")
|
//logger.Info("sleeping")
|
||||||
//time.Sleep(10 * time.Second)
|
//time.Sleep(10 * time.Second)
|
||||||
b := []byte(msg)
|
b := []byte(msg)
|
||||||
conn.Write(b)
|
conn.Write(b)
|
||||||
client.Waiting.Done()
|
client.Waiting.Done()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *EchoHandler)Close()error {
|
func (h *EchoHandler) Close() error {
|
||||||
logger.Info("handler shuting down...")
|
logger.Info("handler shuting down...")
|
||||||
h.closing.Set(true)
|
h.closing.Set(true)
|
||||||
// TODO: concurrent wait
|
// TODO: concurrent wait
|
||||||
h.activeConn.Range(func(key interface{}, val interface{})bool {
|
h.activeConn.Range(func(key interface{}, val interface{}) bool {
|
||||||
client := key.(*Client)
|
client := key.(*Client)
|
||||||
client.Close()
|
client.Close()
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -5,75 +5,74 @@ package tcp
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/HDT3213/godis/src/interface/tcp"
|
"github.com/HDT3213/godis/src/interface/tcp"
|
||||||
"github.com/HDT3213/godis/src/lib/logger"
|
"github.com/HDT3213/godis/src/lib/logger"
|
||||||
"github.com/HDT3213/godis/src/lib/sync/atomic"
|
"github.com/HDT3213/godis/src/lib/sync/atomic"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Address string `yaml:"address"`
|
Address string `yaml:"address"`
|
||||||
MaxConnect uint32 `yaml:"max-connect"`
|
MaxConnect uint32 `yaml:"max-connect"`
|
||||||
Timeout time.Duration `yaml:"timeout"`
|
Timeout time.Duration `yaml:"timeout"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListenAndServe(cfg *Config, handler tcp.Handler) {
|
func ListenAndServe(cfg *Config, handler tcp.Handler) {
|
||||||
listener, err := net.Listen("tcp", cfg.Address)
|
listener, err := net.Listen("tcp", cfg.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal(fmt.Sprintf("listen err: %v", err))
|
logger.Fatal(fmt.Sprintf("listen err: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// listen signal
|
// listen signal
|
||||||
var closing atomic.AtomicBool
|
var closing atomic.AtomicBool
|
||||||
sigCh := make(chan os.Signal, 1)
|
sigCh := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigCh, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT)
|
signal.Notify(sigCh, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT)
|
||||||
go func() {
|
go func() {
|
||||||
sig := <-sigCh
|
sig := <-sigCh
|
||||||
switch sig {
|
switch sig {
|
||||||
case syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT:
|
case syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT:
|
||||||
logger.Info("shuting down...")
|
logger.Info("shuting down...")
|
||||||
closing.Set(true)
|
closing.Set(true)
|
||||||
_ = listener.Close() // listener.Accept() will return err immediately
|
_ = listener.Close() // listener.Accept() will return err immediately
|
||||||
_ = handler.Close() // close connections
|
_ = handler.Close() // close connections
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// listen port
|
||||||
// listen port
|
logger.Info(fmt.Sprintf("bind: %s, start listening...", cfg.Address))
|
||||||
logger.Info(fmt.Sprintf("bind: %s, start listening...", cfg.Address))
|
defer func() {
|
||||||
defer func() {
|
// close during unexpected error
|
||||||
// close during unexpected error
|
_ = listener.Close()
|
||||||
_ = listener.Close()
|
_ = handler.Close()
|
||||||
_ = handler.Close()
|
}()
|
||||||
}()
|
ctx, _ := context.WithCancel(context.Background())
|
||||||
ctx, _ := context.WithCancel(context.Background())
|
var waitDone sync.WaitGroup
|
||||||
var waitDone sync.WaitGroup
|
for {
|
||||||
for {
|
conn, err := listener.Accept()
|
||||||
conn, err := listener.Accept()
|
if err != nil {
|
||||||
if err != nil {
|
if closing.Get() {
|
||||||
if closing.Get() {
|
logger.Info("waiting disconnect...")
|
||||||
logger.Info("waiting disconnect...")
|
waitDone.Wait()
|
||||||
waitDone.Wait()
|
return // handler will be closed by defer
|
||||||
return // handler will be closed by defer
|
}
|
||||||
}
|
logger.Error(fmt.Sprintf("accept err: %v", err))
|
||||||
logger.Error(fmt.Sprintf("accept err: %v", err))
|
continue
|
||||||
continue
|
}
|
||||||
}
|
// handle
|
||||||
// handle
|
logger.Info("accept link")
|
||||||
logger.Info("accept link")
|
waitDone.Add(1)
|
||||||
waitDone.Add(1)
|
go func() {
|
||||||
go func() {
|
defer func() {
|
||||||
defer func() {
|
waitDone.Done()
|
||||||
waitDone.Done()
|
}()
|
||||||
}()
|
handler.Handle(ctx, conn)
|
||||||
handler.Handle(ctx, conn)
|
}()
|
||||||
}()
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user