This commit is contained in:
hdt3213
2021-05-12 00:27:29 +08:00
parent 55ada39252
commit 97d7b84276
58 changed files with 490 additions and 343 deletions

View File

@@ -13,7 +13,8 @@ middleware using golang.
Gods implemented most features of redis, including 5 data structures, ttl, publish/subscribe, geo and AOF persistence. Gods implemented most features of redis, including 5 data structures, ttl, publish/subscribe, geo and AOF persistence.
Godis can run as a server side cluster which is transparent to client. You can connect to any node in the cluster to access all data in the cluster: Godis can run as a server side cluster which is transparent to client. You can connect to any node in the cluster to
access all data in the cluster:
If you could read Chinese, you can find more details in [My Blog](https://www.cnblogs.com/Finley/category/1598973.html). If you could read Chinese, you can find more details in [My Blog](https://www.cnblogs.com/Finley/category/1598973.html).

View File

@@ -1,4 +1,4 @@
# Godis # Godis
![license](https://img.shields.io/github/license/HDT3213/godis) ![license](https://img.shields.io/github/license/HDT3213/godis)
[![Build Status](https://travis-ci.org/HDT3213/godis.svg?branch=master)](https://travis-ci.org/HDT3213/godis) [![Build Status](https://travis-ci.org/HDT3213/godis.svg?branch=master)](https://travis-ci.org/HDT3213/godis)

View File

@@ -23,7 +23,7 @@ func TestAof(t *testing.T) {
defer func() { defer func() {
_ = os.Remove(aofFilename) _ = os.Remove(aofFilename)
}() }()
config.Properties = &config.PropertyHolder{ config.Properties = &config.ServerProperties{
AppendOnly: true, AppendOnly: true,
AppendFilename: aofFilename, AppendFilename: aofFilename,
} }
@@ -93,7 +93,7 @@ func TestRewriteAOF(t *testing.T) {
defer func() { defer func() {
_ = os.Remove(aofFilename) _ = os.Remove(aofFilename)
}() }()
config.Properties = &config.PropertyHolder{ config.Properties = &config.ServerProperties{
AppendOnly: true, AppendOnly: true,
AppendFilename: aofFilename, AppendFilename: aofFilename,
} }

View File

@@ -29,7 +29,7 @@ type Cluster struct {
db *godis.DB db *godis.DB
transactions *dict.SimpleDict // id -> Transaction transactions *dict.SimpleDict // id -> Transaction
idGenerator *idgenerator.IdGenerator idGenerator *idgenerator.IDGenerator
} }
const ( const (

View File

@@ -24,7 +24,7 @@ func Del(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
} }
// prepare // prepare
var errReply redis.Reply var errReply redis.Reply
txID := cluster.idGenerator.NextId() txID := cluster.idGenerator.NextID()
txIDStr := strconv.FormatInt(txID, 10) txIDStr := strconv.FormatInt(txID, 10)
rollback := false rollback := false
for peer, group := range groupMap { for peer, group := range groupMap {

View File

@@ -105,7 +105,7 @@ func MSet(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
//prepare //prepare
var errReply redis.Reply var errReply redis.Reply
txID := cluster.idGenerator.NextId() txID := cluster.idGenerator.NextID()
txIDStr := strconv.FormatInt(txID, 10) txIDStr := strconv.FormatInt(txID, 10)
rollback := false rollback := false
for peer, group := range groupMap { for peer, group := range groupMap {

View File

@@ -9,7 +9,7 @@ var testCluster = MakeTestCluster(nil)
func MakeTestCluster(peers []string) *Cluster { func MakeTestCluster(peers []string) *Cluster {
if config.Properties == nil { if config.Properties == nil {
config.Properties = &config.PropertyHolder{} config.Properties = &config.ServerProperties{}
} }
config.Properties.Self = "127.0.0.1:6399" config.Properties.Self = "127.0.0.1:6399"
config.Properties.Peers = peers config.Properties.Peers = peers

View File

@@ -10,7 +10,8 @@ import (
"strings" "strings"
) )
type PropertyHolder struct { // ServerProperties defines global config properties
type ServerProperties struct {
Bind string `cfg:"bind"` Bind string `cfg:"bind"`
Port int `cfg:"port"` Port int `cfg:"port"`
AppendOnly bool `cfg:"appendOnly"` AppendOnly bool `cfg:"appendOnly"`
@@ -22,19 +23,20 @@ type PropertyHolder struct {
Self string `cfg:"self"` Self string `cfg:"self"`
} }
var Properties *PropertyHolder // Properties holds global config properties
var Properties *ServerProperties
func init() { func init() {
// default config // default config
Properties = &PropertyHolder{ Properties = &ServerProperties{
Bind: "127.0.0.1", Bind: "127.0.0.1",
Port: 6379, Port: 6379,
AppendOnly: false, AppendOnly: false,
} }
} }
func parse(src io.Reader) *PropertyHolder { func parse(src io.Reader) *ServerProperties {
config := &PropertyHolder{} config := &ServerProperties{}
// read config file // read config file
rawMap := make(map[string]string) rawMap := make(map[string]string)
@@ -91,6 +93,7 @@ func parse(src io.Reader) *PropertyHolder {
return config return config
} }
// SetupConfig read config file and store properties into Properties
func SetupConfig(configFilename string) { func SetupConfig(configFilename string) {
file, err := os.Open(configFilename) file, err := os.Open(configFilename)
if err != nil { if err != nil {

View File

@@ -7,12 +7,13 @@ import (
"sync/atomic" "sync/atomic"
) )
// ConcurrentDict is thread safe map using sharding lock
type ConcurrentDict struct { type ConcurrentDict struct {
table []*Shard table []*shard
count int32 count int32
} }
type Shard struct { type shard struct {
m map[string]interface{} m map[string]interface{}
mutex sync.RWMutex mutex sync.RWMutex
} }
@@ -29,16 +30,16 @@ func computeCapacity(param int) (size int) {
n |= n >> 16 n |= n >> 16
if n < 0 { if n < 0 {
return math.MaxInt32 return math.MaxInt32
} else {
return int(n + 1)
} }
return n + 1
} }
// MakeConcurrent creates ConcurrentDict with the given shard count
func MakeConcurrent(shardCount int) *ConcurrentDict { func MakeConcurrent(shardCount int) *ConcurrentDict {
shardCount = computeCapacity(shardCount) shardCount = computeCapacity(shardCount)
table := make([]*Shard, shardCount) table := make([]*shard, shardCount)
for i := 0; i < shardCount; i++ { for i := 0; i < shardCount; i++ {
table[i] = &Shard{ table[i] = &shard{
m: make(map[string]interface{}), m: make(map[string]interface{}),
} }
} }
@@ -68,13 +69,14 @@ func (dict *ConcurrentDict) spread(hashCode uint32) uint32 {
return (tableSize - 1) & uint32(hashCode) return (tableSize - 1) & uint32(hashCode)
} }
func (dict *ConcurrentDict) getShard(index uint32) *Shard { func (dict *ConcurrentDict) getShard(index uint32) *shard {
if dict == nil { if dict == nil {
panic("dict is nil") panic("dict is nil")
} }
return dict.table[index] return dict.table[index]
} }
// Get returns the binding value and whether the key is exist
func (dict *ConcurrentDict) Get(key string) (val interface{}, exists bool) { func (dict *ConcurrentDict) Get(key string) (val interface{}, exists bool) {
if dict == nil { if dict == nil {
panic("dict is nil") panic("dict is nil")
@@ -88,6 +90,7 @@ func (dict *ConcurrentDict) Get(key string) (val interface{}, exists bool) {
return return
} }
// Len returns the number of dict
func (dict *ConcurrentDict) Len() int { func (dict *ConcurrentDict) Len() int {
if dict == nil { if dict == nil {
panic("dict is nil") panic("dict is nil")
@@ -95,7 +98,7 @@ func (dict *ConcurrentDict) Len() int {
return int(atomic.LoadInt32(&dict.count)) return int(atomic.LoadInt32(&dict.count))
} }
// return the number of new inserted key-value // Put puts key value into dict and returns the number of new inserted key-value
func (dict *ConcurrentDict) Put(key string, val interface{}) (result int) { func (dict *ConcurrentDict) Put(key string, val interface{}) (result int) {
if dict == nil { if dict == nil {
panic("dict is nil") panic("dict is nil")
@@ -109,14 +112,13 @@ func (dict *ConcurrentDict) Put(key string, val interface{}) (result int) {
if _, ok := shard.m[key]; ok { if _, ok := shard.m[key]; ok {
shard.m[key] = val shard.m[key] = val
return 0 return 0
} else {
shard.m[key] = val
dict.addCount()
return 1
} }
shard.m[key] = val
dict.addCount()
return 1
} }
// return the number of updated key-value // PutIfAbsent puts value if the key is not exists and returns the number of updated key-value
func (dict *ConcurrentDict) PutIfAbsent(key string, val interface{}) (result int) { func (dict *ConcurrentDict) PutIfAbsent(key string, val interface{}) (result int) {
if dict == nil { if dict == nil {
panic("dict is nil") panic("dict is nil")
@@ -129,14 +131,13 @@ func (dict *ConcurrentDict) PutIfAbsent(key string, val interface{}) (result int
if _, ok := shard.m[key]; ok { if _, ok := shard.m[key]; ok {
return 0 return 0
} else {
shard.m[key] = val
dict.addCount()
return 1
} }
shard.m[key] = val
dict.addCount()
return 1
} }
// return the number of updated key-value // PutIfExists puts value if the key is exist and returns the number of inserted key-value
func (dict *ConcurrentDict) PutIfExists(key string, val interface{}) (result int) { func (dict *ConcurrentDict) PutIfExists(key string, val interface{}) (result int) {
if dict == nil { if dict == nil {
panic("dict is nil") panic("dict is nil")
@@ -150,12 +151,11 @@ func (dict *ConcurrentDict) PutIfExists(key string, val interface{}) (result int
if _, ok := shard.m[key]; ok { if _, ok := shard.m[key]; ok {
shard.m[key] = val shard.m[key] = val
return 1 return 1
} else {
return 0
} }
return 0
} }
// return the number of deleted key-value // Remove removes the key and return the number of deleted key-value
func (dict *ConcurrentDict) Remove(key string) (result int) { func (dict *ConcurrentDict) Remove(key string) (result int) {
if dict == nil { if dict == nil {
panic("dict is nil") panic("dict is nil")
@@ -169,19 +169,16 @@ func (dict *ConcurrentDict) Remove(key string) (result int) {
if _, ok := shard.m[key]; ok { if _, ok := shard.m[key]; ok {
delete(shard.m, key) delete(shard.m, key)
return 1 return 1
} else {
return 0
} }
return return 0
} }
func (dict *ConcurrentDict) addCount() int32 { func (dict *ConcurrentDict) addCount() int32 {
return atomic.AddInt32(&dict.count, 1) return atomic.AddInt32(&dict.count, 1)
} }
/* // ForEach traversal the dict
* may not contains new entry inserted during traversal // it may not visits new entry inserted during traversal
*/
func (dict *ConcurrentDict) ForEach(consumer Consumer) { func (dict *ConcurrentDict) ForEach(consumer Consumer) {
if dict == nil { if dict == nil {
panic("dict is nil") panic("dict is nil")
@@ -201,6 +198,7 @@ func (dict *ConcurrentDict) ForEach(consumer Consumer) {
} }
} }
// Keys returns all keys in dict
func (dict *ConcurrentDict) Keys() []string { func (dict *ConcurrentDict) Keys() []string {
keys := make([]string, dict.Len()) keys := make([]string, dict.Len())
i := 0 i := 0
@@ -216,7 +214,8 @@ func (dict *ConcurrentDict) Keys() []string {
return keys return keys
} }
func (shard *Shard) RandomKey() string { // RandomKey returns a key randomly
func (shard *shard) RandomKey() string {
if shard == nil { if shard == nil {
panic("shard is nil") panic("shard is nil")
} }
@@ -229,6 +228,7 @@ func (shard *Shard) RandomKey() string {
return "" return ""
} }
// RandomKeys randomly returns keys of the given number, may contain duplicated key
func (dict *ConcurrentDict) RandomKeys(limit int) []string { func (dict *ConcurrentDict) RandomKeys(limit int) []string {
size := dict.Len() size := dict.Len()
if limit >= size { if limit >= size {
@@ -251,6 +251,7 @@ func (dict *ConcurrentDict) RandomKeys(limit int) []string {
return result return result
} }
// RandomDistinctKeys randomly returns keys of the given number, won't contain duplicated key
func (dict *ConcurrentDict) RandomDistinctKeys(limit int) []string { func (dict *ConcurrentDict) RandomDistinctKeys(limit int) []string {
size := dict.Len() size := dict.Len()
if limit >= size { if limit >= size {

View File

@@ -278,4 +278,4 @@ func TestConcurrentDict_Keys(t *testing.T) {
if len(d.Keys()) != size { if len(d.Keys()) != size {
t.Errorf("expect %d keys, actual: %d", size, len(d.Keys())) t.Errorf("expect %d keys, actual: %d", size, len(d.Keys()))
} }
} }

View File

@@ -1,7 +1,9 @@
package dict package dict
// Consumer is used to traversal dict, if it returns false the traversal will be break
type Consumer func(key string, val interface{}) bool type Consumer func(key string, val interface{}) bool
// Dict is interface of a key-value data structure
type Dict interface { type Dict interface {
Get(key string) (val interface{}, exists bool) Get(key string) (val interface{}, exists bool)
Len() int Len() int

View File

@@ -1,20 +1,24 @@
package dict package dict
// SimpleDict wraps a map, it is not thread safe
type SimpleDict struct { type SimpleDict struct {
m map[string]interface{} m map[string]interface{}
} }
// MakeSimple makes a new map
func MakeSimple() *SimpleDict { func MakeSimple() *SimpleDict {
return &SimpleDict{ return &SimpleDict{
m: make(map[string]interface{}), m: make(map[string]interface{}),
} }
} }
// Get returns the binding value and whether the key is exist
func (dict *SimpleDict) Get(key string) (val interface{}, exists bool) { func (dict *SimpleDict) Get(key string) (val interface{}, exists bool) {
val, ok := dict.m[key] val, ok := dict.m[key]
return val, ok return val, ok
} }
// Len returns the number of dict
func (dict *SimpleDict) Len() int { func (dict *SimpleDict) Len() int {
if dict.m == nil { if dict.m == nil {
panic("m is nil") panic("m is nil")
@@ -22,46 +26,47 @@ func (dict *SimpleDict) Len() int {
return len(dict.m) return len(dict.m)
} }
// Put puts key value into dict and returns the number of new inserted key-value
func (dict *SimpleDict) Put(key string, val interface{}) (result int) { func (dict *SimpleDict) Put(key string, val interface{}) (result int) {
_, existed := dict.m[key] _, existed := dict.m[key]
dict.m[key] = val dict.m[key] = val
if existed { if existed {
return 0 return 0
} else {
return 1
} }
return 1
} }
// PutIfAbsent puts value if the key is not exists and returns the number of updated key-value
func (dict *SimpleDict) PutIfAbsent(key string, val interface{}) (result int) { func (dict *SimpleDict) PutIfAbsent(key string, val interface{}) (result int) {
_, existed := dict.m[key] _, existed := dict.m[key]
if existed { if existed {
return 0 return 0
} else {
dict.m[key] = val
return 1
} }
dict.m[key] = val
return 1
} }
// PutIfExists puts value if the key is exist and returns the number of inserted key-value
func (dict *SimpleDict) PutIfExists(key string, val interface{}) (result int) { func (dict *SimpleDict) PutIfExists(key string, val interface{}) (result int) {
_, existed := dict.m[key] _, existed := dict.m[key]
if existed { if existed {
dict.m[key] = val dict.m[key] = val
return 1 return 1
} else {
return 0
} }
return 0
} }
// Remove removes the key and return the number of deleted key-value
func (dict *SimpleDict) Remove(key string) (result int) { func (dict *SimpleDict) Remove(key string) (result int) {
_, existed := dict.m[key] _, existed := dict.m[key]
delete(dict.m, key) delete(dict.m, key)
if existed { if existed {
return 1 return 1
} else {
return 0
} }
return 0
} }
// Keys returns all keys in dict
func (dict *SimpleDict) Keys() []string { func (dict *SimpleDict) Keys() []string {
result := make([]string, len(dict.m)) result := make([]string, len(dict.m))
i := 0 i := 0
@@ -71,6 +76,7 @@ func (dict *SimpleDict) Keys() []string {
return result return result
} }
// ForEach traversal the dict
func (dict *SimpleDict) ForEach(consumer Consumer) { func (dict *SimpleDict) ForEach(consumer Consumer) {
for k, v := range dict.m { for k, v := range dict.m {
if !consumer(k, v) { if !consumer(k, v) {
@@ -79,6 +85,7 @@ func (dict *SimpleDict) ForEach(consumer Consumer) {
} }
} }
// RandomKeys randomly returns keys of the given number, may contain duplicated key
func (dict *SimpleDict) RandomKeys(limit int) []string { func (dict *SimpleDict) RandomKeys(limit int) []string {
result := make([]string, limit) result := make([]string, limit)
for i := 0; i < limit; i++ { for i := 0; i < limit; i++ {
@@ -90,6 +97,7 @@ func (dict *SimpleDict) RandomKeys(limit int) []string {
return result return result
} }
// RandomDistinctKeys randomly returns keys of the given number, won't contain duplicated key
func (dict *SimpleDict) RandomDistinctKeys(limit int) []string { func (dict *SimpleDict) RandomDistinctKeys(limit int) []string {
size := limit size := limit
if size > len(dict.m) { if size > len(dict.m) {

View File

@@ -2,6 +2,7 @@ package list
import "github.com/hdt3213/godis/datastruct/utils" import "github.com/hdt3213/godis/datastruct/utils"
// LinkedList is doubly linked list
type LinkedList struct { type LinkedList struct {
first *node first *node
last *node last *node
@@ -14,6 +15,7 @@ type node struct {
next *node next *node
} }
// Add adds value to the tail
func (list *LinkedList) Add(val interface{}) { func (list *LinkedList) Add(val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
@@ -35,20 +37,20 @@ func (list *LinkedList) Add(val interface{}) {
func (list *LinkedList) find(index int) (n *node) { func (list *LinkedList) find(index int) (n *node) {
if index < list.size/2 { if index < list.size/2 {
n := list.first n = list.first
for i := 0; i < index; i++ { for i := 0; i < index; i++ {
n = n.next n = n.next
} }
return n
} else { } else {
n := list.last n = list.last
for i := list.size - 1; i > index; i-- { for i := list.size - 1; i > index; i-- {
n = n.prev n = n.prev
} }
return n
} }
return n
} }
// Get returns value at the given index
func (list *LinkedList) Get(index int) (val interface{}) { func (list *LinkedList) Get(index int) (val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
@@ -59,6 +61,7 @@ func (list *LinkedList) Get(index int) (val interface{}) {
return list.find(index).val return list.find(index).val
} }
// Set updates value at the given index, the index should between [0, list.size]
func (list *LinkedList) Set(index int, val interface{}) { func (list *LinkedList) Set(index int, val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
@@ -70,6 +73,7 @@ func (list *LinkedList) Set(index int, val interface{}) {
n.val = val n.val = val
} }
// Insert inserts value at the given index, the original element at the given index will move backward
func (list *LinkedList) Insert(index int, val interface{}) { func (list *LinkedList) Insert(index int, val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
@@ -81,22 +85,21 @@ func (list *LinkedList) Insert(index int, val interface{}) {
if index == list.size { if index == list.size {
list.Add(val) list.Add(val)
return return
} else {
// list is not empty
pivot := list.find(index)
n := &node{
val: val,
prev: pivot.prev,
next: pivot,
}
if pivot.prev == nil {
list.first = n
} else {
pivot.prev.next = n
}
pivot.prev = n
list.size++
} }
// list is not empty
pivot := list.find(index)
n := &node{
val: val,
prev: pivot.prev,
next: pivot,
}
if pivot.prev == nil {
list.first = n
} else {
pivot.prev.next = n
}
pivot.prev = n
list.size++
} }
func (list *LinkedList) removeNode(n *node) { func (list *LinkedList) removeNode(n *node) {
@@ -118,6 +121,7 @@ func (list *LinkedList) removeNode(n *node) {
list.size-- list.size--
} }
// Remove removes value at the given index
func (list *LinkedList) Remove(index int) (val interface{}) { func (list *LinkedList) Remove(index int) (val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
@@ -131,6 +135,7 @@ func (list *LinkedList) Remove(index int) (val interface{}) {
return n.val return n.val
} }
// RemoveLast removes the last element and returns its value
func (list *LinkedList) RemoveLast() (val interface{}) { func (list *LinkedList) RemoveLast() (val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
@@ -144,6 +149,7 @@ func (list *LinkedList) RemoveLast() (val interface{}) {
return n.val return n.val
} }
// RemoveAllByVal removes all elements with the given val
func (list *LinkedList) RemoveAllByVal(val interface{}) int { func (list *LinkedList) RemoveAllByVal(val interface{}) int {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
@@ -172,10 +178,8 @@ func (list *LinkedList) RemoveAllByVal(val interface{}) int {
return removed return removed
} }
/** // RemoveByVal removes at most `count` values of the specified value in this list
* remove at most `count` values of the specified value in this list // scan from left to right
* scan from left to right
*/
func (list *LinkedList) RemoveByVal(val interface{}, count int) int { func (list *LinkedList) RemoveByVal(val interface{}, count int) int {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
@@ -208,6 +212,8 @@ func (list *LinkedList) RemoveByVal(val interface{}, count int) int {
return removed return removed
} }
// ReverseRemoveByVal removes at most `count` values of the specified value in this list
// scan from right to left
func (list *LinkedList) ReverseRemoveByVal(val interface{}, count int) int { func (list *LinkedList) ReverseRemoveByVal(val interface{}, count int) int {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
@@ -240,6 +246,7 @@ func (list *LinkedList) ReverseRemoveByVal(val interface{}, count int) int {
return removed return removed
} }
// Len returns the number of elements in list
func (list *LinkedList) Len() int { func (list *LinkedList) Len() int {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
@@ -247,6 +254,8 @@ func (list *LinkedList) Len() int {
return list.size return list.size
} }
// ForEach visits each element in the list
// if the consumer returns false, the loop will be break
func (list *LinkedList) ForEach(consumer func(int, interface{}) bool) { func (list *LinkedList) ForEach(consumer func(int, interface{}) bool) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
@@ -264,6 +273,7 @@ func (list *LinkedList) ForEach(consumer func(int, interface{}) bool) {
} }
} }
// Contains returns whether the given value exist in the list
func (list *LinkedList) Contains(val interface{}) bool { func (list *LinkedList) Contains(val interface{}) bool {
contains := false contains := false
list.ForEach(func(i int, actual interface{}) bool { list.ForEach(func(i int, actual interface{}) bool {
@@ -276,6 +286,7 @@ func (list *LinkedList) Contains(val interface{}) bool {
return contains return contains
} }
// Range returns elements which index within [start, stop)
func (list *LinkedList) Range(start int, stop int) []interface{} { func (list *LinkedList) Range(start int, stop int) []interface{} {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
@@ -309,6 +320,7 @@ func (list *LinkedList) Range(start int, stop int) []interface{} {
return slice return slice
} }
// Make creates a new linked list
func Make(vals ...interface{}) *LinkedList { func Make(vals ...interface{}) *LinkedList {
list := LinkedList{} list := LinkedList{}
for _, v := range vals { for _, v := range vals {

View File

@@ -9,10 +9,12 @@ const (
prime32 = uint32(16777619) prime32 = uint32(16777619)
) )
// Locks provides rw locks for key
type Locks struct { type Locks struct {
table []*sync.RWMutex table []*sync.RWMutex
} }
// Make creates a new lock map
func Make(tableSize int) *Locks { func Make(tableSize int) *Locks {
table := make([]*sync.RWMutex, tableSize) table := make([]*sync.RWMutex, tableSize)
for i := 0; i < tableSize; i++ { for i := 0; i < tableSize; i++ {
@@ -40,24 +42,28 @@ func (locks *Locks) spread(hashCode uint32) uint32 {
return (tableSize - 1) & uint32(hashCode) return (tableSize - 1) & uint32(hashCode)
} }
// Lock obtains exclusive lock for writing
func (locks *Locks) Lock(key string) { func (locks *Locks) Lock(key string) {
index := locks.spread(fnv32(key)) index := locks.spread(fnv32(key))
mu := locks.table[index] mu := locks.table[index]
mu.Lock() mu.Lock()
} }
// RLock obtains shared lock for reading
func (locks *Locks) RLock(key string) { func (locks *Locks) RLock(key string) {
index := locks.spread(fnv32(key)) index := locks.spread(fnv32(key))
mu := locks.table[index] mu := locks.table[index]
mu.RLock() mu.RLock()
} }
// UnLock release exclusive lock
func (locks *Locks) UnLock(key string) { func (locks *Locks) UnLock(key string) {
index := locks.spread(fnv32(key)) index := locks.spread(fnv32(key))
mu := locks.table[index] mu := locks.table[index]
mu.Unlock() mu.Unlock()
} }
// RUnLock release shared lock
func (locks *Locks) RUnLock(key string) { func (locks *Locks) RUnLock(key string) {
index := locks.spread(fnv32(key)) index := locks.spread(fnv32(key))
mu := locks.table[index] mu := locks.table[index]
@@ -77,13 +83,14 @@ func (locks *Locks) toLockIndices(keys []string, reverse bool) []uint32 {
sort.Slice(indices, func(i, j int) bool { sort.Slice(indices, func(i, j int) bool {
if !reverse { if !reverse {
return indices[i] < indices[j] return indices[i] < indices[j]
} else {
return indices[i] > indices[j]
} }
return indices[i] > indices[j]
}) })
return indices return indices
} }
// Locks obtains multiple exclusive locks for writing
// invoking Lock in loop may cause dead lock, please use Locks
func (locks *Locks) Locks(keys ...string) { func (locks *Locks) Locks(keys ...string) {
indices := locks.toLockIndices(keys, false) indices := locks.toLockIndices(keys, false)
for _, index := range indices { for _, index := range indices {
@@ -92,6 +99,8 @@ func (locks *Locks) Locks(keys ...string) {
} }
} }
// RLocks obtains multiple shared locks for reading
// invoking RLock in loop may cause dead lock, please use RLocks
func (locks *Locks) RLocks(keys ...string) { func (locks *Locks) RLocks(keys ...string) {
indices := locks.toLockIndices(keys, false) indices := locks.toLockIndices(keys, false)
for _, index := range indices { for _, index := range indices {
@@ -100,6 +109,7 @@ func (locks *Locks) RLocks(keys ...string) {
} }
} }
// UnLocks releases multiple exclusive locks
func (locks *Locks) UnLocks(keys ...string) { func (locks *Locks) UnLocks(keys ...string) {
indices := locks.toLockIndices(keys, true) indices := locks.toLockIndices(keys, true)
for _, index := range indices { for _, index := range indices {
@@ -108,6 +118,7 @@ func (locks *Locks) UnLocks(keys ...string) {
} }
} }
// RUnLocks releases multiple shared locks
func (locks *Locks) RUnLocks(keys ...string) { func (locks *Locks) RUnLocks(keys ...string) {
indices := locks.toLockIndices(keys, true) indices := locks.toLockIndices(keys, true)
for _, index := range indices { for _, index := range indices {

View File

@@ -2,19 +2,15 @@ package set
import "github.com/hdt3213/godis/datastruct/dict" import "github.com/hdt3213/godis/datastruct/dict"
// Set is a set of elements based on hash table
type Set struct { type Set struct {
dict dict.Dict dict dict.Dict
} }
func Make() *Set { // Make creates a new set
return &Set{ func Make(members ...string) *Set {
dict: dict.MakeSimple(),
}
}
func MakeFromVals(members ...string) *Set {
set := &Set{ set := &Set{
dict: dict.MakeConcurrent(len(members)), dict: dict.MakeSimple(),
} }
for _, member := range members { for _, member := range members {
set.Add(member) set.Add(member)
@@ -22,23 +18,28 @@ func MakeFromVals(members ...string) *Set {
return set return set
} }
// Add adds member into set
func (set *Set) Add(val string) int { func (set *Set) Add(val string) int {
return set.dict.Put(val, nil) return set.dict.Put(val, nil)
} }
// Remove removes member from set
func (set *Set) Remove(val string) int { func (set *Set) Remove(val string) int {
return set.dict.Remove(val) return set.dict.Remove(val)
} }
// Has returns true if the val exists in the set
func (set *Set) Has(val string) bool { func (set *Set) Has(val string) bool {
_, exists := set.dict.Get(val) _, exists := set.dict.Get(val)
return exists return exists
} }
// Len returns number of members in the set
func (set *Set) Len() int { func (set *Set) Len() int {
return set.dict.Len() return set.dict.Len()
} }
// ToSlice convert set to []string
func (set *Set) ToSlice() []string { func (set *Set) ToSlice() []string {
slice := make([]string, set.Len()) slice := make([]string, set.Len())
i := 0 i := 0
@@ -55,12 +56,14 @@ func (set *Set) ToSlice() []string {
return slice return slice
} }
// ForEach visits each member in the set
func (set *Set) ForEach(consumer func(member string) bool) { func (set *Set) ForEach(consumer func(member string) bool) {
set.dict.ForEach(func(key string, val interface{}) bool { set.dict.ForEach(func(key string, val interface{}) bool {
return consumer(key) return consumer(key)
}) })
} }
// Intersect intersects two sets
func (set *Set) Intersect(another *Set) *Set { func (set *Set) Intersect(another *Set) *Set {
if set == nil { if set == nil {
panic("set is nil") panic("set is nil")
@@ -76,6 +79,7 @@ func (set *Set) Intersect(another *Set) *Set {
return result return result
} }
// Union adds two sets
func (set *Set) Union(another *Set) *Set { func (set *Set) Union(another *Set) *Set {
if set == nil { if set == nil {
panic("set is nil") panic("set is nil")
@@ -92,6 +96,7 @@ func (set *Set) Union(another *Set) *Set {
return result return result
} }
// Diff subtracts two sets
func (set *Set) Diff(another *Set) *Set { func (set *Set) Diff(another *Set) *Set {
if set == nil { if set == nil {
panic("set is nil") panic("set is nil")
@@ -107,10 +112,12 @@ func (set *Set) Diff(another *Set) *Set {
return result return result
} }
// RandomMembers randomly returns keys of the given number, may contain duplicated key
func (set *Set) RandomMembers(limit int) []string { func (set *Set) RandomMembers(limit int) []string {
return set.dict.RandomKeys(limit) return set.dict.RandomKeys(limit)
} }
// RandomDistinctMembers randomly returns keys of the given number, won't contain duplicated key
func (set *Set) RandomDistinctMembers(limit int) []string { func (set *Set) RandomDistinctMembers(limit int) []string {
return set.dict.RandomDistinctKeys(limit) return set.dict.RandomDistinctKeys(limit)
} }

View File

@@ -18,6 +18,7 @@ const (
positiveInf int8 = 1 positiveInf int8 = 1
) )
// ScoreBorder represents range of a float value, including: <, <=, >, >=, +inf, -inf
type ScoreBorder struct { type ScoreBorder struct {
Inf int8 Inf int8
Value float64 Value float64
@@ -34,9 +35,8 @@ func (border *ScoreBorder) greater(value float64) bool {
} }
if border.Exclude { if border.Exclude {
return border.Value > value return border.Value > value
} else {
return border.Value >= value
} }
return border.Value >= value
} }
func (border *ScoreBorder) less(value float64) bool { func (border *ScoreBorder) less(value float64) bool {
@@ -47,9 +47,8 @@ func (border *ScoreBorder) less(value float64) bool {
} }
if border.Exclude { if border.Exclude {
return border.Value < value return border.Value < value
} else {
return border.Value <= value
} }
return border.Value <= value
} }
var positiveInfBorder = &ScoreBorder{ var positiveInfBorder = &ScoreBorder{
@@ -60,6 +59,7 @@ var negativeInfBorder = &ScoreBorder{
Inf: negativeInf, Inf: negativeInf,
} }
// ParseScoreBorder creates ScoreBorder from redis arguments
func ParseScoreBorder(s string) (*ScoreBorder, error) { func ParseScoreBorder(s string) (*ScoreBorder, error) {
if s == "inf" || s == "+inf" { if s == "inf" || s == "+inf" {
return positiveInfBorder, nil return positiveInfBorder, nil
@@ -77,15 +77,14 @@ func ParseScoreBorder(s string) (*ScoreBorder, error) {
Value: value, Value: value,
Exclude: true, Exclude: true,
}, nil }, nil
} else {
value, err := strconv.ParseFloat(s, 64)
if err != nil {
return nil, errors.New("ERR min or max is not a float")
}
return &ScoreBorder{
Inf: 0,
Value: value,
Exclude: false,
}, nil
} }
value, err := strconv.ParseFloat(s, 64)
if err != nil {
return nil, errors.New("ERR min or max is not a float")
}
return &ScoreBorder{
Inf: 0,
Value: value,
Exclude: false,
}, nil
} }

View File

@@ -6,32 +6,33 @@ const (
maxLevel = 16 maxLevel = 16
) )
// Element is a key-score pair
type Element struct { type Element struct {
Member string Member string
Score float64 Score float64
} }
// level aspect of a Node // Level aspect of a node
type Level struct { type Level struct {
forward *Node // forward node has greater score forward *node // forward node has greater score
span int64 span int64
} }
type Node struct { type node struct {
Element Element
backward *Node backward *node
level []*Level // level[0] is base level level []*Level // level[0] is base level
} }
type skiplist struct { type skiplist struct {
header *Node header *node
tail *Node tail *node
length int64 length int64
level int16 level int16
} }
func makeNode(level int16, score float64, member string) *Node { func makeNode(level int16, score float64, member string) *node {
n := &Node{ n := &node{
Element: Element{ Element: Element{
Score: score, Score: score,
Member: member, Member: member,
@@ -62,8 +63,8 @@ func randomLevel() int16 {
return maxLevel return maxLevel
} }
func (skiplist *skiplist) insert(member string, score float64) *Node { func (skiplist *skiplist) insert(member string, score float64) *node {
update := make([]*Node, maxLevel) // link new node with node in `update` update := make([]*node, maxLevel) // link new node with node in `update`
rank := make([]int64, maxLevel) rank := make([]int64, maxLevel)
// find position to insert // find position to insert
@@ -132,7 +133,7 @@ func (skiplist *skiplist) insert(member string, score float64) *Node {
* param node: node to delete * param node: node to delete
* param update: backward node (of target) * param update: backward node (of target)
*/ */
func (skiplist *skiplist) removeNode(node *Node, update []*Node) { func (skiplist *skiplist) removeNode(node *node, update []*node) {
for i := int16(0); i < skiplist.level; i++ { for i := int16(0); i < skiplist.level; i++ {
if update[i].level[i].forward == node { if update[i].level[i].forward == node {
update[i].level[i].span += node.level[i].span - 1 update[i].level[i].span += node.level[i].span - 1
@@ -160,7 +161,7 @@ func (skiplist *skiplist) remove(member string, score float64) bool {
* find backward node (of target) or last node of each level * find backward node (of target) or last node of each level
* their forward need to be updated * their forward need to be updated
*/ */
update := make([]*Node, maxLevel) update := make([]*node, maxLevel)
node := skiplist.header node := skiplist.header
for i := skiplist.level - 1; i >= 0; i-- { for i := skiplist.level - 1; i >= 0; i-- {
for node.level[i].forward != nil && for node.level[i].forward != nil &&
@@ -206,7 +207,7 @@ func (skiplist *skiplist) getRank(member string, score float64) int64 {
/* /*
* 1-based rank * 1-based rank
*/ */
func (skiplist *skiplist) getByRank(rank int64) *Node { func (skiplist *skiplist) getByRank(rank int64) *node {
var i int64 = 0 var i int64 = 0
n := skiplist.header n := skiplist.header
// scan from top level // scan from top level
@@ -240,7 +241,7 @@ func (skiplist *skiplist) hasInRange(min *ScoreBorder, max *ScoreBorder) bool {
return true return true
} }
func (skiplist *skiplist) getFirstInScoreRange(min *ScoreBorder, max *ScoreBorder) *Node { func (skiplist *skiplist) getFirstInScoreRange(min *ScoreBorder, max *ScoreBorder) *node {
if !skiplist.hasInRange(min, max) { if !skiplist.hasInRange(min, max) {
return nil return nil
} }
@@ -260,7 +261,7 @@ func (skiplist *skiplist) getFirstInScoreRange(min *ScoreBorder, max *ScoreBorde
return n return n
} }
func (skiplist *skiplist) getLastInScoreRange(min *ScoreBorder, max *ScoreBorder) *Node { func (skiplist *skiplist) getLastInScoreRange(min *ScoreBorder, max *ScoreBorder) *node {
if !skiplist.hasInRange(min, max) { if !skiplist.hasInRange(min, max) {
return nil return nil
} }
@@ -281,7 +282,7 @@ func (skiplist *skiplist) getLastInScoreRange(min *ScoreBorder, max *ScoreBorder
* return removed elements * return removed elements
*/ */
func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder) (removed []*Element) { func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder) (removed []*Element) {
update := make([]*Node, maxLevel) update := make([]*node, maxLevel)
removed = make([]*Element, 0) removed = make([]*Element, 0)
// find backward nodes (of target range) or last node of each level // find backward nodes (of target range) or last node of each level
node := skiplist.header node := skiplist.header
@@ -315,7 +316,7 @@ func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder)
// 1-based rank, including start, exclude stop // 1-based rank, including start, exclude stop
func (skiplist *skiplist) RemoveRangeByRank(start int64, stop int64) (removed []*Element) { func (skiplist *skiplist) RemoveRangeByRank(start int64, stop int64) (removed []*Element) {
var i int64 = 0 // rank of iterator var i int64 = 0 // rank of iterator
update := make([]*Node, maxLevel) update := make([]*node, maxLevel)
removed = make([]*Element, 0) removed = make([]*Element, 0)
// scan from top level // scan from top level

View File

@@ -4,11 +4,13 @@ import (
"strconv" "strconv"
) )
// SortedSet is a set which keys sorted by bound score
type SortedSet struct { type SortedSet struct {
dict map[string]*Element dict map[string]*Element
skiplist *skiplist skiplist *skiplist
} }
// Make makes a new SortedSet
func Make() *SortedSet { func Make() *SortedSet {
return &SortedSet{ return &SortedSet{
dict: make(map[string]*Element), dict: make(map[string]*Element),
@@ -16,9 +18,7 @@ func Make() *SortedSet {
} }
} }
/* // Add puts member into set, and returns whether has inserted new node
* return: has inserted new node
*/
func (sortedSet *SortedSet) Add(member string, score float64) bool { func (sortedSet *SortedSet) Add(member string, score float64) bool {
element, ok := sortedSet.dict[member] element, ok := sortedSet.dict[member]
sortedSet.dict[member] = &Element{ sortedSet.dict[member] = &Element{
@@ -31,16 +31,17 @@ func (sortedSet *SortedSet) Add(member string, score float64) bool {
sortedSet.skiplist.insert(member, score) sortedSet.skiplist.insert(member, score)
} }
return false return false
} else {
sortedSet.skiplist.insert(member, score)
return true
} }
sortedSet.skiplist.insert(member, score)
return true
} }
// Len returns number of members in set
func (sortedSet *SortedSet) Len() int64 { func (sortedSet *SortedSet) Len() int64 {
return int64(len(sortedSet.dict)) return int64(len(sortedSet.dict))
} }
// Get returns the given member
func (sortedSet *SortedSet) Get(member string) (element *Element, ok bool) { func (sortedSet *SortedSet) Get(member string) (element *Element, ok bool) {
element, ok = sortedSet.dict[member] element, ok = sortedSet.dict[member]
if !ok { if !ok {
@@ -49,6 +50,7 @@ func (sortedSet *SortedSet) Get(member string) (element *Element, ok bool) {
return element, true return element, true
} }
// Remove removes the given member from set
func (sortedSet *SortedSet) Remove(member string) bool { func (sortedSet *SortedSet) Remove(member string) bool {
v, ok := sortedSet.dict[member] v, ok := sortedSet.dict[member]
if ok { if ok {
@@ -59,9 +61,7 @@ func (sortedSet *SortedSet) Remove(member string) bool {
return false return false
} }
/** // GetRank returns the rank of the given member, sort by ascending order, rank starts from 0
* get 0-based rank
*/
func (sortedSet *SortedSet) GetRank(member string, desc bool) (rank int64) { func (sortedSet *SortedSet) GetRank(member string, desc bool) (rank int64) {
element, ok := sortedSet.dict[member] element, ok := sortedSet.dict[member]
if !ok { if !ok {
@@ -76,9 +76,7 @@ func (sortedSet *SortedSet) GetRank(member string, desc bool) (rank int64) {
return r return r
} }
/** // ForEach visits each member which rank within [start, stop), sort by ascending order, rank starts from 0
* traverse [start, stop), 0-based rank
*/
func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer func(element *Element) bool) { func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer func(element *Element) bool) {
size := int64(sortedSet.Len()) size := int64(sortedSet.Len())
if start < 0 || start >= size { if start < 0 || start >= size {
@@ -89,7 +87,7 @@ func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer
} }
// find start node // find start node
var node *Node var node *node
if desc { if desc {
node = sortedSet.skiplist.tail node = sortedSet.skiplist.tail
if start > 0 { if start > 0 {
@@ -115,10 +113,7 @@ func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer
} }
} }
/** // Range returns members which rank within [start, stop), sort by ascending order, rank starts from 0
* return [start, stop), 0-based rank
* assert start in [0, size), stop in [start, size]
*/
func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool) []*Element { func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool) []*Element {
sliceSize := int(stop - start) sliceSize := int(stop - start)
slice := make([]*Element, sliceSize) slice := make([]*Element, sliceSize)
@@ -131,6 +126,7 @@ func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool) []*Element
return slice return slice
} }
// Count returns the number of members which score within the given border
func (sortedSet *SortedSet) Count(min *ScoreBorder, max *ScoreBorder) int64 { func (sortedSet *SortedSet) Count(min *ScoreBorder, max *ScoreBorder) int64 {
var i int64 = 0 var i int64 = 0
// ascending order // ascending order
@@ -152,9 +148,10 @@ func (sortedSet *SortedSet) Count(min *ScoreBorder, max *ScoreBorder) int64 {
return i return i
} }
// ForEachByScore visits members which score within the given border
func (sortedSet *SortedSet) ForEachByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool, consumer func(element *Element) bool) { func (sortedSet *SortedSet) ForEachByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool, consumer func(element *Element) bool) {
// find start node // find start node
var node *Node var node *node
if desc { if desc {
node = sortedSet.skiplist.getLastInScoreRange(min, max) node = sortedSet.skiplist.getLastInScoreRange(min, max)
} else { } else {
@@ -191,9 +188,8 @@ func (sortedSet *SortedSet) ForEachByScore(min *ScoreBorder, max *ScoreBorder, o
} }
} }
/* // RangeByScore returns members which score within the given border
* param limit: <0 means no limit // param limit: <0 means no limit
*/
func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool) []*Element { func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool) []*Element {
if limit == 0 || offset < 0 { if limit == 0 || offset < 0 {
return make([]*Element, 0) return make([]*Element, 0)
@@ -206,6 +202,7 @@ func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, off
return slice return slice
} }
// RemoveByScore removes members which score within the given border
func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder) int64 { func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder) int64 {
removed := sortedSet.skiplist.RemoveRangeByScore(min, max) removed := sortedSet.skiplist.RemoveRangeByScore(min, max)
for _, element := range removed { for _, element := range removed {
@@ -214,9 +211,8 @@ func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder) in
return int64(len(removed)) return int64(len(removed))
} }
/* // RemoveByRank removes member ranking within [start, stop)
* 0-based rank, [start, stop) // sort by ascending order and rank starts from 0
*/
func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64) int64 { func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64) int64 {
removed := sortedSet.skiplist.RemoveRangeByRank(start+1, stop+1) removed := sortedSet.skiplist.RemoveRangeByRank(start+1, stop+1)
for _, element := range removed { for _, element := range removed {

View File

@@ -1,5 +1,6 @@
package utils package utils
// Equals check whether the given value is equal
func Equals(a interface{}, b interface{}) bool { func Equals(a interface{}, b interface{}) bool {
sliceA, okA := a.([]byte) sliceA, okA := a.([]byte)
sliceB, okB := b.([]byte) sliceB, okB := b.([]byte)
@@ -9,6 +10,7 @@ func Equals(a interface{}, b interface{}) bool {
return a == b return a == b
} }
// BytesEquals check whether the given bytes is equal
func BytesEquals(a []byte, b []byte) bool { func BytesEquals(a []byte, b []byte) bool {
if (a == nil && b != nil) || (a != nil && b == nil) { if (a == nil && b != nil) || (a != nil && b == nil) {
return false return false

View File

@@ -96,6 +96,9 @@ func HGet(db *DB, args [][]byte) redis.Reply {
key := string(args[0]) key := string(args[0])
field := string(args[1]) field := string(args[1])
db.RLock(key)
defer db.RUnLock(key)
// get entity // get entity
dict, errReply := db.getAsDict(key) dict, errReply := db.getAsDict(key)
if errReply != nil { if errReply != nil {
@@ -122,6 +125,9 @@ func HExists(db *DB, args [][]byte) redis.Reply {
key := string(args[0]) key := string(args[0])
field := string(args[1]) field := string(args[1])
db.RLock(key)
defer db.RUnLock(key)
// get entity // get entity
dict, errReply := db.getAsDict(key) dict, errReply := db.getAsDict(key)
if errReply != nil { if errReply != nil {
@@ -186,6 +192,9 @@ func HLen(db *DB, args [][]byte) redis.Reply {
} }
key := string(args[0]) key := string(args[0])
db.RLock(key)
defer db.RUnLock(key)
dict, errReply := db.getAsDict(key) dict, errReply := db.getAsDict(key)
if errReply != nil { if errReply != nil {
return errReply return errReply

View File

@@ -2,6 +2,7 @@ package db
import "github.com/hdt3213/godis/interface/redis" import "github.com/hdt3213/godis/interface/redis"
// DB is the interface for redis style storage engine
type DB interface { type DB interface {
Exec(client redis.Connection, args [][]byte) redis.Reply Exec(client redis.Connection, args [][]byte) redis.Reply
AfterClientClose(c redis.Connection) AfterClientClose(c redis.Connection)

View File

@@ -1,5 +1,6 @@
package redis package redis
// Connection represents a connection with redis client
type Connection interface { type Connection interface {
Write([]byte) error Write([]byte) error
SetPassword(string) SetPassword(string)

View File

@@ -1,5 +1,6 @@
package redis package redis
// Reply is the interface of redis serialization protocol message
type Reply interface { type Reply interface {
ToBytes() []byte ToBytes() []byte
} }

View File

@@ -5,6 +5,7 @@ import (
"net" "net"
) )
// HandleFunc represents application handler function
type HandleFunc func(ctx context.Context, conn net.Conn) type HandleFunc func(ctx context.Context, conn net.Conn)
// Handler represents application server over tcp // Handler represents application server over tcp

View File

@@ -7,8 +7,10 @@ import (
"strings" "strings"
) )
// HashFunc defines function to generate hash code
type HashFunc func(data []byte) uint32 type HashFunc func(data []byte) uint32
// Map stores nodes and you can pick node from Map
type Map struct { type Map struct {
hashFunc HashFunc hashFunc HashFunc
replicas int replicas int
@@ -16,6 +18,7 @@ type Map struct {
hashMap map[int]string hashMap map[int]string
} }
// New creates a new Map
func New(replicas int, fn HashFunc) *Map { func New(replicas int, fn HashFunc) *Map {
m := &Map{ m := &Map{
replicas: replicas, replicas: replicas,
@@ -28,10 +31,12 @@ func New(replicas int, fn HashFunc) *Map {
return m return m
} }
// IsEmpty returns if there is no node in Map
func (m *Map) IsEmpty() bool { func (m *Map) IsEmpty() bool {
return len(m.keys) == 0 return len(m.keys) == 0
} }
// AddNode add the given nodes into consistent hash circle
func (m *Map) AddNode(keys ...string) { func (m *Map) AddNode(keys ...string) {
for _, key := range keys { for _, key := range keys {
if key == "" { if key == "" {

View File

@@ -1,62 +0,0 @@
package files
import (
"fmt"
"os"
)
func CheckNotExist(src string) bool {
_, err := os.Stat(src)
return os.IsNotExist(err)
}
func CheckPermission(src string) bool {
_, err := os.Stat(src)
return os.IsPermission(err)
}
func IsNotExistMkDir(src string) error {
if notExist := CheckNotExist(src); notExist == true {
if err := MkDir(src); err != nil {
return err
}
}
return nil
}
func MkDir(src string) error {
err := os.MkdirAll(src, os.ModePerm)
if err != nil {
return err
}
return nil
}
func Open(name string, flag int, perm os.FileMode) (*os.File, error) {
f, err := os.OpenFile(name, flag, perm)
if err != nil {
return nil, err
}
return f, nil
}
func MustOpen(fileName, dir string) (*os.File, error) {
perm := CheckPermission(dir)
if perm == true {
return nil, fmt.Errorf("permission denied dir: %s", dir)
}
err := IsNotExistMkDir(dir)
if err != nil {
return nil, fmt.Errorf("error during make dir %s, err: %s", dir, err)
}
f, err := Open(dir+string(os.PathSeparator)+fileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0644)
if err != nil {
return nil, fmt.Errorf("fail to open file, err: %s", err)
}
return f, nil
}

View File

@@ -50,6 +50,7 @@ func encode0(latitude, longitude float64, bitSize uint) ([]byte, [2][2]float64)
return hash.Bytes(), box return hash.Bytes(), box
} }
// Encode converts latitude and longitude to uint64 geohash code
func Encode(latitude, longitude float64) uint64 { func Encode(latitude, longitude float64) uint64 {
buf, _ := encode0(latitude, longitude, defaultBitSize) buf, _ := encode0(latitude, longitude, defaultBitSize)
return binary.BigEndian.Uint64(buf) return binary.BigEndian.Uint64(buf)
@@ -77,6 +78,7 @@ func decode0(hash []byte) [][]float64 {
return box return box
} }
// Decode converts uint64 geohash code to latitude and longitude
func Decode(code uint64) (float64, float64) { func Decode(code uint64) (float64, float64) {
buf := make([]byte, 8) buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, code) binary.BigEndian.PutUint64(buf, code)
@@ -86,10 +88,12 @@ func Decode(code uint64) (float64, float64) {
return lat, lng return lat, lng
} }
// ToString converts bytes geohash code to base32 string
func ToString(buf []byte) string { func ToString(buf []byte) string {
return enc.EncodeToString(buf) return enc.EncodeToString(buf)
} }
// ToInt converts bytes geohash code to uint64 code
func ToInt(buf []byte) uint64 { func ToInt(buf []byte) uint64 {
// padding // padding
if len(buf) < 8 { if len(buf) < 8 {
@@ -100,6 +104,7 @@ func ToInt(buf []byte) uint64 {
return binary.BigEndian.Uint64(buf) return binary.BigEndian.Uint64(buf)
} }
// FromInt converts uint64 geohash code to bytes
func FromInt(code uint64) []byte { func FromInt(code uint64) []byte {
buf := make([]byte, 8) buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, code) binary.BigEndian.PutUint64(buf, code)

View File

@@ -8,13 +8,13 @@ import (
func TestToRange(t *testing.T) { func TestToRange(t *testing.T) {
neighbor := []byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00} neighbor := []byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00}
range_ := ToRange(neighbor, 36) geoRange := toRange(neighbor, 36)
expectedLower := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00}) expectedLower := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xE0, 0x00, 0x00, 0x00})
expectedUpper := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xF0, 0x00, 0x00, 0x00}) expectedUpper := ToInt([]byte{0x00, 0x00, 0x00, 0x00, 0xF0, 0x00, 0x00, 0x00})
if expectedLower != range_[0] { if expectedLower != geoRange[0] {
t.Error("incorrect lower") t.Error("incorrect lower")
} }
if expectedUpper != range_[1] { if expectedUpper != geoRange[1] {
t.Error("incorrect upper") t.Error("incorrect upper")
} }
} }

View File

@@ -3,39 +3,18 @@ package geohash
import "math" import "math"
const ( const (
DR = math.Pi / 180.0 dr = math.Pi / 180.0
EarthRadius = 6372797.560856 earthRadius = 6372797.560856
MercatorMax = 20037726.37 // pi * EarthRadius mercatorMax = 20037726.37 // pi * earthRadius
MercatorMin = -20037726.37 mercatorMin = -20037726.37
) )
func degRad(ang float64) float64 { func degRad(ang float64) float64 {
return ang * DR return ang * dr
} }
func radDeg(ang float64) float64 { func radDeg(ang float64) float64 {
return ang / DR return ang / dr
}
func getBoundingBox(latitude float64, longitude float64, radiusMeters float64) (
minLat, maxLat, minLng, maxLng float64) {
minLng = longitude - radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude)))
if minLng < -180 {
minLng = -180
}
maxLng = longitude + radDeg(radiusMeters/EarthRadius/math.Cos(degRad(latitude)))
if maxLng > 180 {
maxLng = 180
}
minLat = latitude - radDeg(radiusMeters/EarthRadius)
if minLat < -90 {
minLat = -90
}
maxLat = latitude + radDeg(radiusMeters/EarthRadius)
if maxLat > 90 {
maxLat = 90
}
return
} }
func estimatePrecisionByRadius(radiusMeters float64, latitude float64) uint { func estimatePrecisionByRadius(radiusMeters float64, latitude float64) uint {
@@ -43,7 +22,7 @@ func estimatePrecisionByRadius(radiusMeters float64, latitude float64) uint {
return defaultBitSize - 1 return defaultBitSize - 1
} }
var precision uint = 1 var precision uint = 1
for radiusMeters < MercatorMax { for radiusMeters < mercatorMax {
radiusMeters *= 2 radiusMeters *= 2
precision++ precision++
} }
@@ -64,16 +43,18 @@ func estimatePrecisionByRadius(radiusMeters float64, latitude float64) uint {
return precision*2 - 1 return precision*2 - 1
} }
// Distance computes the distance between two given coordinates in meter
func Distance(latitude1, longitude1, latitude2, longitude2 float64) float64 { func Distance(latitude1, longitude1, latitude2, longitude2 float64) float64 {
radLat1 := degRad(latitude1) radLat1 := degRad(latitude1)
radLat2 := degRad(latitude2) radLat2 := degRad(latitude2)
a := radLat1 - radLat2 a := radLat1 - radLat2
b := degRad(longitude1) - degRad(longitude2) b := degRad(longitude1) - degRad(longitude2)
return 2 * EarthRadius * math.Asin(math.Sqrt(math.Pow(math.Sin(a/2), 2)+ return 2 * earthRadius * math.Asin(math.Sqrt(math.Pow(math.Sin(a/2), 2)+
math.Cos(radLat1)*math.Cos(radLat2)*math.Pow(math.Sin(b/2), 2))) math.Cos(radLat1)*math.Cos(radLat2)*math.Pow(math.Sin(b/2), 2)))
} }
func ToRange(scope []byte, precision uint) [2]uint64 { // toRange covert geohash prefix to uint64 range
func toRange(scope []byte, precision uint) [2]uint64 {
lower := ToInt(scope) lower := ToInt(scope)
radius := uint64(1 << (64 - precision)) radius := uint64(1 << (64 - precision))
upper := lower + radius upper := lower + radius
@@ -100,6 +81,7 @@ func ensureValidLng(lng float64) float64 {
return lng return lng
} }
// GetNeighbours returns geohash code of blocks within radiusMeters to the given coordinate
func GetNeighbours(latitude, longitude, radiusMeters float64) [][2]uint64 { func GetNeighbours(latitude, longitude, radiusMeters float64) [][2]uint64 {
precision := estimatePrecisionByRadius(radiusMeters, latitude) precision := estimatePrecisionByRadius(radiusMeters, latitude)
@@ -115,22 +97,22 @@ func GetNeighbours(latitude, longitude, radiusMeters float64) [][2]uint64 {
var result [10][2]uint64 var result [10][2]uint64
leftUpper, _ := encode0(maxLat, minLng, precision) leftUpper, _ := encode0(maxLat, minLng, precision)
result[1] = ToRange(leftUpper, precision) result[1] = toRange(leftUpper, precision)
upper, _ := encode0(maxLat, centerLng, precision) upper, _ := encode0(maxLat, centerLng, precision)
result[2] = ToRange(upper, precision) result[2] = toRange(upper, precision)
rightUpper, _ := encode0(maxLat, maxLng, precision) rightUpper, _ := encode0(maxLat, maxLng, precision)
result[3] = ToRange(rightUpper, precision) result[3] = toRange(rightUpper, precision)
left, _ := encode0(centerLat, minLng, precision) left, _ := encode0(centerLat, minLng, precision)
result[4] = ToRange(left, precision) result[4] = toRange(left, precision)
result[5] = ToRange(center, precision) result[5] = toRange(center, precision)
right, _ := encode0(centerLat, maxLng, precision) right, _ := encode0(centerLat, maxLng, precision)
result[6] = ToRange(right, precision) result[6] = toRange(right, precision)
leftDown, _ := encode0(minLat, minLng, precision) leftDown, _ := encode0(minLat, minLng, precision)
result[7] = ToRange(leftDown, precision) result[7] = toRange(leftDown, precision)
down, _ := encode0(minLat, centerLng, precision) down, _ := encode0(minLat, centerLng, precision)
result[8] = ToRange(down, precision) result[8] = toRange(down, precision)
rightDown, _ := encode0(minLat, maxLng, precision) rightDown, _ := encode0(minLat, maxLng, precision)
result[9] = ToRange(rightDown, precision) result[9] = toRange(rightDown, precision)
return result[1:] return result[1:]
} }

View File

@@ -8,41 +8,44 @@ import (
) )
const ( const (
// Epoch is set to the twitter snowflake epoch of Nov 04 2010 01:42:54 UTC in milliseconds // epoch0 is set to the twitter snowflake epoch of Nov 04 2010 01:42:54 UTC in milliseconds
// You may customize this to set a different epoch for your application. // You may customize this to set a different epoch for your application.
Epoch int64 = 1288834974657 epoch0 int64 = 1288834974657
maxSequence int64 = -1 ^ (-1 << uint64(nodeLeft)) maxSequence int64 = -1 ^ (-1 << uint64(nodeLeft))
timeLeft uint8 = 22 timeLeft uint8 = 22
nodeLeft uint8 = 10 nodeLeft uint8 = 10
nodeMask int64 = -1 ^ (-1 << uint64(timeLeft-nodeLeft)) nodeMask int64 = -1 ^ (-1 << uint64(timeLeft-nodeLeft))
) )
type IdGenerator struct { // IDGenerator generates unique uint64 ID using snowflake algorithm
type IDGenerator struct {
mu *sync.Mutex mu *sync.Mutex
lastStamp int64 lastStamp int64
workerId int64 nodeID int64
sequence int64 sequence int64
epoch time.Time epoch time.Time
} }
func MakeGenerator(node string) *IdGenerator { // MakeGenerator creates a new IDGenerator
func MakeGenerator(node string) *IDGenerator {
fnv64 := fnv.New64() fnv64 := fnv.New64()
_, _ = fnv64.Write([]byte(node)) _, _ = fnv64.Write([]byte(node))
nodeId := int64(fnv64.Sum64()) & nodeMask nodeID := int64(fnv64.Sum64()) & nodeMask
var curTime = time.Now() var curTime = time.Now()
epoch := curTime.Add(time.Unix(Epoch/1000, (Epoch%1000)*1000000).Sub(curTime)) epoch := curTime.Add(time.Unix(epoch0/1000, (epoch0%1000)*1000000).Sub(curTime))
return &IdGenerator{ return &IDGenerator{
mu: &sync.Mutex{}, mu: &sync.Mutex{},
lastStamp: -1, lastStamp: -1,
workerId: nodeId, nodeID: nodeID,
sequence: 1, sequence: 1,
epoch: epoch, epoch: epoch,
} }
} }
func (w *IdGenerator) NextId() int64 { // NextID returns next unique ID
func (w *IDGenerator) NextID() int64 {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
@@ -61,7 +64,7 @@ func (w *IdGenerator) NextId() int64 {
w.sequence = 0 w.sequence = 0
} }
w.lastStamp = timestamp w.lastStamp = timestamp
id := (timestamp << timeLeft) | (w.workerId << nodeLeft) | w.sequence id := (timestamp << timeLeft) | (w.nodeID << nodeLeft) | w.sequence
//fmt.Printf("%d %d %d\n", timestamp, w.sequence, id) //fmt.Printf("%d %d %d\n", timestamp, w.sequence, id)
return id return id
} }

View File

@@ -7,7 +7,7 @@ func TestMGenerator(t *testing.T) {
ids := make(map[int64]struct{}) ids := make(map[int64]struct{})
size := int(1e6) size := int(1e6)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
id := gen.NextId() id := gen.NextID()
_, ok := ids[id] _, ok := ids[id]
if ok { if ok {
t.Errorf("duplicated id: %d, time: %d, seq: %d", id, gen.lastStamp, gen.sequence) t.Errorf("duplicated id: %d, time: %d, seq: %d", id, gen.lastStamp, gen.sequence)

53
lib/logger/files.go Normal file
View File

@@ -0,0 +1,53 @@
package logger
import (
"fmt"
"os"
)
func checkNotExist(src string) bool {
_, err := os.Stat(src)
return os.IsNotExist(err)
}
func checkPermission(src string) bool {
_, err := os.Stat(src)
return os.IsPermission(err)
}
func isNotExistMkDir(src string) error {
if notExist := checkNotExist(src); notExist == true {
if err := mkDir(src); err != nil {
return err
}
}
return nil
}
func mkDir(src string) error {
err := os.MkdirAll(src, os.ModePerm)
if err != nil {
return err
}
return nil
}
func mustOpen(fileName, dir string) (*os.File, error) {
perm := checkPermission(dir)
if perm == true {
return nil, fmt.Errorf("permission denied dir: %s", dir)
}
err := isNotExistMkDir(dir)
if err != nil {
return nil, fmt.Errorf("error during make dir %s, err: %s", dir, err)
}
f, err := os.OpenFile(dir+string(os.PathSeparator)+fileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0644)
if err != nil {
return nil, fmt.Errorf("fail to open file, err: %s", err)
}
return f, nil
}

View File

@@ -2,7 +2,6 @@ package logger
import ( import (
"fmt" "fmt"
"github.com/hdt3213/godis/lib/files"
"io" "io"
"log" "log"
"os" "os"
@@ -11,6 +10,7 @@ import (
"time" "time"
) )
// Settings stores config for logger
type Settings struct { type Settings struct {
Path string `yaml:"path"` Path string `yaml:"path"`
Name string `yaml:"name"` Name string `yaml:"name"`
@@ -19,18 +19,19 @@ type Settings struct {
} }
var ( var (
F *os.File logFile *os.File
DefaultPrefix = "" defaultPrefix = ""
DefaultCallerDepth = 2 defaultCallerDepth = 2
logger *log.Logger logger *log.Logger
logPrefix = "" logPrefix = ""
levelFlags = []string{"DEBUG", "INFO", "WARN", "ERROR", "FATAL"} levelFlags = []string{"DEBUG", "INFO", "WARN", "ERROR", "FATAL"}
) )
type Level int type logLevel int
// log levels
const ( const (
DEBUG Level = iota DEBUG logLevel = iota
INFO INFO
WARNING WARNING
ERROR ERROR
@@ -40,9 +41,10 @@ const (
const flags = log.LstdFlags const flags = log.LstdFlags
func init() { func init() {
logger = log.New(os.Stdout, DefaultPrefix, flags) logger = log.New(os.Stdout, defaultPrefix, flags)
} }
// Setup initializes logger
func Setup(settings *Settings) { func Setup(settings *Settings) {
var err error var err error
dir := settings.Path dir := settings.Path
@@ -51,17 +53,17 @@ func Setup(settings *Settings) {
time.Now().Format(settings.TimeFormat), time.Now().Format(settings.TimeFormat),
settings.Ext) settings.Ext)
logFile, err := files.MustOpen(fileName, dir) logFile, err := mustOpen(fileName, dir)
if err != nil { if err != nil {
log.Fatalf("logging.Setup err: %s", err) log.Fatalf("logging.Setup err: %s", err)
} }
mw := io.MultiWriter(os.Stdout, logFile) mw := io.MultiWriter(os.Stdout, logFile)
logger = log.New(mw, DefaultPrefix, flags) logger = log.New(mw, defaultPrefix, flags)
} }
func setPrefix(level Level) { func setPrefix(level logLevel) {
_, file, line, ok := runtime.Caller(DefaultCallerDepth) _, file, line, ok := runtime.Caller(defaultCallerDepth)
if ok { if ok {
logPrefix = fmt.Sprintf("[%s][%s:%d] ", levelFlags[level], filepath.Base(file), line) logPrefix = fmt.Sprintf("[%s][%s:%d] ", levelFlags[level], filepath.Base(file), line)
} else { } else {
@@ -71,26 +73,31 @@ func setPrefix(level Level) {
logger.SetPrefix(logPrefix) logger.SetPrefix(logPrefix)
} }
// Debug prints debug log
func Debug(v ...interface{}) { func Debug(v ...interface{}) {
setPrefix(DEBUG) setPrefix(DEBUG)
logger.Println(v...) logger.Println(v...)
} }
// Info prints normal log
func Info(v ...interface{}) { func Info(v ...interface{}) {
setPrefix(INFO) setPrefix(INFO)
logger.Println(v...) logger.Println(v...)
} }
// Warn prints warning log
func Warn(v ...interface{}) { func Warn(v ...interface{}) {
setPrefix(WARNING) setPrefix(WARNING)
logger.Println(v...) logger.Println(v...)
} }
// Error prints error log
func Error(v ...interface{}) { func Error(v ...interface{}) {
setPrefix(ERROR) setPrefix(ERROR)
logger.Println(v...) logger.Println(v...)
} }
// Fatal prints error log then stop the program
func Fatal(v ...interface{}) { func Fatal(v ...interface{}) {
setPrefix(FATAL) setPrefix(FATAL)
logger.Fatalln(v...) logger.Fatalln(v...)

View File

@@ -2,13 +2,16 @@ package atomic
import "sync/atomic" import "sync/atomic"
type AtomicBool uint32 // Boolean is a boolean value, all actions of it is atomic
type Boolean uint32
func (b *AtomicBool) Get() bool { // Get reads the value atomically
func (b *Boolean) Get() bool {
return atomic.LoadUint32((*uint32)(b)) != 0 return atomic.LoadUint32((*uint32)(b)) != 0
} }
func (b *AtomicBool) Set(v bool) { // Set writes the value atomically
func (b *Boolean) Set(v bool) {
if v { if v {
atomic.StoreUint32((*uint32)(b), 1) atomic.StoreUint32((*uint32)(b), 1)
} else { } else {

View File

@@ -5,23 +5,28 @@ import (
"time" "time"
) )
// Wait is similar with sync.WaitGroup which can wait with timeout
type Wait struct { type Wait struct {
wg sync.WaitGroup wg sync.WaitGroup
} }
// Add adds delta, which may be negative, to the WaitGroup counter.
func (w *Wait) Add(delta int) { func (w *Wait) Add(delta int) {
w.wg.Add(delta) w.wg.Add(delta)
} }
// Done decrements the WaitGroup counter by one
func (w *Wait) Done() { func (w *Wait) Done() {
w.wg.Done() w.wg.Done()
} }
// Wait blocks until the WaitGroup counter is zero.
func (w *Wait) Wait() { func (w *Wait) Wait() {
w.wg.Wait() w.wg.Wait()
} }
// return isTimeout // WaitWithTimeout blocks until the WaitGroup counter is zero or timeout
// returns true if timeout
func (w *Wait) WaitWithTimeout(timeout time.Duration) bool { func (w *Wait) WaitWithTimeout(timeout time.Duration) bool {
c := make(chan bool) c := make(chan bool)
go func() { go func() {

View File

@@ -8,14 +8,17 @@ func init() {
tw.Start() tw.Start()
} }
// Delay executes job after waiting the given duration
func Delay(duration time.Duration, key string, job func()) { func Delay(duration time.Duration, key string, job func()) {
tw.AddTimer(duration, key, job) tw.AddJob(duration, key, job)
} }
// At executes job at given time
func At(at time.Time, key string, job func()) { func At(at time.Time, key string, job func()) {
tw.AddTimer(at.Sub(time.Now()), key, job) tw.AddJob(at.Sub(time.Now()), key, job)
} }
// Cancel stops a pending job
func Cancel(key string) { func Cancel(key string) {
tw.RemoveTimer(key) tw.RemoveJob(key)
} }

View File

@@ -6,6 +6,7 @@ import (
"time" "time"
) )
// TimeWheel can execute job after waiting given duration
type TimeWheel struct { type TimeWheel struct {
interval time.Duration interval time.Duration
ticker *time.Ticker ticker *time.Ticker
@@ -14,18 +15,19 @@ type TimeWheel struct {
timer map[string]int timer map[string]int
currentPos int currentPos int
slotNum int slotNum int
addTaskChannel chan Task addTaskChannel chan task
removeTaskChannel chan string removeTaskChannel chan string
stopChannel chan bool stopChannel chan bool
} }
type Task struct { type task struct {
delay time.Duration delay time.Duration
circle int circle int
key string key string
job func() job func()
} }
// New creates a new time wheel
func New(interval time.Duration, slotNum int) *TimeWheel { func New(interval time.Duration, slotNum int) *TimeWheel {
if interval <= 0 || slotNum <= 0 { if interval <= 0 || slotNum <= 0 {
return nil return nil
@@ -36,7 +38,7 @@ func New(interval time.Duration, slotNum int) *TimeWheel {
timer: make(map[string]int), timer: make(map[string]int),
currentPos: 0, currentPos: 0,
slotNum: slotNum, slotNum: slotNum,
addTaskChannel: make(chan Task), addTaskChannel: make(chan task),
removeTaskChannel: make(chan string), removeTaskChannel: make(chan string),
stopChannel: make(chan bool), stopChannel: make(chan bool),
} }
@@ -51,23 +53,28 @@ func (tw *TimeWheel) initSlots() {
} }
} }
// Start starts ticker for time wheel
func (tw *TimeWheel) Start() { func (tw *TimeWheel) Start() {
tw.ticker = time.NewTicker(tw.interval) tw.ticker = time.NewTicker(tw.interval)
go tw.start() go tw.start()
} }
// Stop stops the time wheel
func (tw *TimeWheel) Stop() { func (tw *TimeWheel) Stop() {
tw.stopChannel <- true tw.stopChannel <- true
} }
func (tw *TimeWheel) AddTimer(delay time.Duration, key string, job func()) { // AddJob add new job into pending queue
func (tw *TimeWheel) AddJob(delay time.Duration, key string, job func()) {
if delay < 0 { if delay < 0 {
return return
} }
tw.addTaskChannel <- Task{delay: delay, key: key, job: job} tw.addTaskChannel <- task{delay: delay, key: key, job: job}
} }
func (tw *TimeWheel) RemoveTimer(key string) { // RemoveJob add remove job from pending queue
// if job is done or not found, then nothing happened
func (tw *TimeWheel) RemoveJob(key string) {
if key == "" { if key == "" {
return return
} }
@@ -102,7 +109,7 @@ func (tw *TimeWheel) tickHandler() {
func (tw *TimeWheel) scanAndRunTask(l *list.List) { func (tw *TimeWheel) scanAndRunTask(l *list.List) {
for e := l.Front(); e != nil; { for e := l.Front(); e != nil; {
task := e.Value.(*Task) task := e.Value.(*task)
if task.circle > 0 { if task.circle > 0 {
task.circle-- task.circle--
e = e.Next() e = e.Next()
@@ -127,7 +134,7 @@ func (tw *TimeWheel) scanAndRunTask(l *list.List) {
} }
} }
func (tw *TimeWheel) addTask(task *Task) { func (tw *TimeWheel) addTask(task *task) {
pos, circle := tw.getPositionAndCircle(task.delay) pos, circle := tw.getPositionAndCircle(task.delay)
task.circle = circle task.circle = circle
@@ -154,7 +161,7 @@ func (tw *TimeWheel) removeTask(key string) {
} }
l := tw.slots[position] l := tw.slots[position]
for e := l.Front(); e != nil; { for e := l.Front(); e != nil; {
task := e.Value.(*Task) task := e.Value.(*task)
if task.key == key { if task.key == key {
delete(tw.timer, task.key) delete(tw.timer, task.key)
l.Remove(e) l.Remove(e)

View File

@@ -1,5 +1,6 @@
package utils package utils
// ToBytesList convert strings to [][]byte
func ToBytesList(cmd ...string) [][]byte { func ToBytesList(cmd ...string) [][]byte {
args := make([][]byte, len(cmd)) args := make([][]byte, len(cmd))
for i, s := range cmd { for i, s := range cmd {

View File

@@ -5,12 +5,14 @@ import (
"io" "io"
) )
// LimitedReader implements io.Reader, but you can only read the given number of bytes
type LimitedReader struct { type LimitedReader struct {
src io.Reader src io.Reader
n int n int
limit int limit int
} }
// NewLimitedReader wraps an io.Reader to LimitedReader
func NewLimitedReader(src io.Reader, limit int) *LimitedReader { func NewLimitedReader(src io.Reader, limit int) *LimitedReader {
return &LimitedReader{ return &LimitedReader{
src: src, src: src,
@@ -18,6 +20,7 @@ func NewLimitedReader(src io.Reader, limit int) *LimitedReader {
} }
} }
// Read reads up to len(p) bytes into p. if meets EOF from src or reach limit, it returns EOF
func (r *LimitedReader) Read(p []byte) (n int, err error) { func (r *LimitedReader) Read(p []byte) (n int, err error) {
if r.src == nil { if r.src == nil {
return 0, errors.New("no data source") return 0, errors.New("no data source")

View File

@@ -4,6 +4,7 @@ import "math/rand"
var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
// RandString create a random string no longer than n
func RandString(n int) string { func RandString(n int) string {
b := make([]rune, n) b := make([]rune, n)
for i := range b { for i := range b {

View File

@@ -1,10 +1,10 @@
package wildcard package wildcard
const ( const (
normal = iota normal = iota
all // * all // *
any // ? any // ?
set_ // [] setSymbol // []
) )
type item struct { type item struct {
@@ -18,10 +18,12 @@ func (i *item) contains(c byte) bool {
return ok return ok
} }
// Pattern represents a wildcard pattern
type Pattern struct { type Pattern struct {
items []*item items []*item
} }
// CompilePattern convert wildcard string to Pattern
func CompilePattern(src string) *Pattern { func CompilePattern(src string) *Pattern {
items := make([]*item, 0) items := make([]*item, 0)
escape := false escape := false
@@ -48,7 +50,7 @@ func CompilePattern(src string) *Pattern {
} else if c == ']' { } else if c == ']' {
if inSet { if inSet {
inSet = false inSet = false
items = append(items, &item{typeCode: set_, set: set}) items = append(items, &item{typeCode: setSymbol, set: set})
} else { } else {
items = append(items, &item{typeCode: normal, character: c}) items = append(items, &item{typeCode: normal, character: c})
} }
@@ -65,6 +67,7 @@ func CompilePattern(src string) *Pattern {
} }
} }
// IsMatch returns whether the given string matches pattern
func (p *Pattern) IsMatch(s string) bool { func (p *Pattern) IsMatch(s string) bool {
if len(p.items) == 0 { if len(p.items) == 0 {
return len(s) == 0 return len(s) == 0
@@ -87,7 +90,7 @@ func (p *Pattern) IsMatch(s string) bool {
table[i][j] = table[i-1][j-1] && table[i][j] = table[i-1][j-1] &&
(p.items[j-1].typeCode == any || (p.items[j-1].typeCode == any ||
(p.items[j-1].typeCode == normal && uint8(s[i-1]) == p.items[j-1].character) || (p.items[j-1].typeCode == normal && uint8(s[i-1]) == p.items[j-1].character) ||
(p.items[j-1].typeCode == set_ && p.items[j-1].contains(s[i-1]))) (p.items[j-1].typeCode == setSymbol && p.items[j-1].contains(s[i-1])))
} }
} }
} }

View File

@@ -5,6 +5,7 @@ import (
"github.com/hdt3213/godis/datastruct/lock" "github.com/hdt3213/godis/datastruct/lock"
) )
// Hub stores all subscribe relations
type Hub struct { type Hub struct {
// channel -> list(*Client) // channel -> list(*Client)
subs dict.Dict subs dict.Dict
@@ -12,6 +13,7 @@ type Hub struct {
subsLocker *lock.Locks subsLocker *lock.Locks
} }
// MakeHub creates new hub
func MakeHub() *Hub { func MakeHub() *Hub {
return &Hub{ return &Hub{
subs: dict.MakeConcurrent(4), subs: dict.MakeConcurrent(4),

View File

@@ -83,6 +83,7 @@ func Subscribe(hub *Hub, c redis.Connection, args [][]byte) redis.Reply {
return &reply.NoReply{} return &reply.NoReply{}
} }
// UnsubscribeAll removes the given connection from all subscribing channel
func UnsubscribeAll(hub *Hub, c redis.Connection) { func UnsubscribeAll(hub *Hub, c redis.Connection) {
channels := c.GetChannels() channels := c.GetChannels()
@@ -95,6 +96,7 @@ func UnsubscribeAll(hub *Hub, c redis.Connection) {
} }
// UnSubscribe removes the given connection from the given channel
func UnSubscribe(db *Hub, c redis.Connection, args [][]byte) redis.Reply { func UnSubscribe(db *Hub, c redis.Connection, args [][]byte) redis.Reply {
var channels []string var channels []string
if len(args) > 0 { if len(args) > 0 {

View File

@@ -12,17 +12,19 @@ import (
"time" "time"
) )
// Client is a pipeline mode redis client
type Client struct { type Client struct {
conn net.Conn conn net.Conn
pendingReqs chan *Request // wait to send pendingReqs chan *request // wait to send
waitingReqs chan *Request // waiting response waitingReqs chan *request // waiting response
ticker *time.Ticker ticker *time.Ticker
addr string addr string
working *sync.WaitGroup // its counter presents unfinished requests(pending and waiting) working *sync.WaitGroup // its counter presents unfinished requests(pending and waiting)
} }
type Request struct { // request is a message sends to redis server
type request struct {
id uint64 id uint64
args [][]byte args [][]byte
reply redis.Reply reply redis.Reply
@@ -36,6 +38,7 @@ const (
maxWait = 3 * time.Second maxWait = 3 * time.Second
) )
// MakeClient creates a new client
func MakeClient(addr string) (*Client, error) { func MakeClient(addr string) (*Client, error) {
conn, err := net.Dial("tcp", addr) conn, err := net.Dial("tcp", addr)
if err != nil { if err != nil {
@@ -44,12 +47,13 @@ func MakeClient(addr string) (*Client, error) {
return &Client{ return &Client{
addr: addr, addr: addr,
conn: conn, conn: conn,
pendingReqs: make(chan *Request, chanSize), pendingReqs: make(chan *request, chanSize),
waitingReqs: make(chan *Request, chanSize), waitingReqs: make(chan *request, chanSize),
working: &sync.WaitGroup{}, working: &sync.WaitGroup{},
}, nil }, nil
} }
// Start starts asynchronous goroutines
func (client *Client) Start() { func (client *Client) Start() {
client.ticker = time.NewTicker(10 * time.Second) client.ticker = time.NewTicker(10 * time.Second)
go client.handleWrite() go client.handleWrite()
@@ -62,6 +66,7 @@ func (client *Client) Start() {
go client.heartbeat() go client.heartbeat()
} }
// Close stops asynchronous goroutines and close connection
func (client *Client) Close() { func (client *Client) Close() {
client.ticker.Stop() client.ticker.Stop()
// stop new request // stop new request
@@ -110,8 +115,9 @@ func (client *Client) handleWrite() {
} }
} }
// Send sends a request to redis server
func (client *Client) Send(args [][]byte) redis.Reply { func (client *Client) Send(args [][]byte) redis.Reply {
request := &Request{ request := &request{
args: args, args: args,
heartbeat: false, heartbeat: false,
waiting: &wait.Wait{}, waiting: &wait.Wait{},
@@ -131,7 +137,7 @@ func (client *Client) Send(args [][]byte) redis.Reply {
} }
func (client *Client) doHeartbeat() { func (client *Client) doHeartbeat() {
request := &Request{ request := &request{
args: [][]byte{[]byte("PING")}, args: [][]byte{[]byte("PING")},
heartbeat: true, heartbeat: true,
waiting: &wait.Wait{}, waiting: &wait.Wait{},
@@ -143,7 +149,7 @@ func (client *Client) doHeartbeat() {
request.waiting.WaitWithTimeout(maxWait) request.waiting.WaitWithTimeout(maxWait)
} }
func (client *Client) doRequest(req *Request) { func (client *Client) doRequest(req *request) {
if req == nil || len(req.args) == 0 { if req == nil || len(req.args) == 0 {
return return
} }

View File

@@ -100,28 +100,34 @@ func (c *Connection) GetChannels() []string {
return channels return channels
} }
// SetPassword stores password for authentication
func (c *Connection) SetPassword(password string) { func (c *Connection) SetPassword(password string) {
c.password = password c.password = password
} }
// GetPassword get password for authentication
func (c *Connection) GetPassword() string { func (c *Connection) GetPassword() string {
return c.password return c.password
} }
// FakeConn implements redis.Connection for test
type FakeConn struct { type FakeConn struct {
Connection Connection
buf bytes.Buffer buf bytes.Buffer
} }
// Write writes data to buffer
func (c *FakeConn) Write(b []byte) error { func (c *FakeConn) Write(b []byte) error {
c.buf.Write(b) c.buf.Write(b)
return nil return nil
} }
// Clean resets the buffer
func (c *FakeConn) Clean() { func (c *FakeConn) Clean() {
c.buf.Reset() c.buf.Reset()
} }
// Bytes returns written data
func (c *FakeConn) Bytes() []byte { func (c *FakeConn) Bytes() []byte {
return c.buf.Bytes() return c.buf.Bytes()
} }

View File

@@ -13,6 +13,7 @@ import (
"strings" "strings"
) )
// Payload stores redis.Reply or error
type Payload struct { type Payload struct {
Data redis.Reply Data redis.Reply
Err error Err error
@@ -71,13 +72,13 @@ func parse0(reader io.Reader, ch chan<- *Payload) {
} }
close(ch) close(ch)
return return
} else { // protocol err, reset read state
ch <- &Payload{
Err: err,
}
state = readState{}
continue
} }
// protocol err, reset read state
ch <- &Payload{
Err: err,
}
state = readState{}
continue
} }
// parse line // parse line

View File

@@ -9,6 +9,7 @@ import (
"testing" "testing"
) )
// AssertIntReply checks if the given redis.Reply is the expected integer
func AssertIntReply(t *testing.T, actual redis.Reply, expected int) { func AssertIntReply(t *testing.T, actual redis.Reply, expected int) {
intResult, ok := actual.(*reply.IntReply) intResult, ok := actual.(*reply.IntReply)
if !ok { if !ok {
@@ -20,6 +21,7 @@ func AssertIntReply(t *testing.T, actual redis.Reply, expected int) {
} }
} }
// AssertBulkReply checks if the given redis.Reply is the expected string
func AssertBulkReply(t *testing.T, actual redis.Reply, expected string) { func AssertBulkReply(t *testing.T, actual redis.Reply, expected string) {
bulkReply, ok := actual.(*reply.BulkReply) bulkReply, ok := actual.(*reply.BulkReply)
if !ok { if !ok {
@@ -31,6 +33,7 @@ func AssertBulkReply(t *testing.T, actual redis.Reply, expected string) {
} }
} }
// AssertStatusReply checks if the given redis.Reply is the expected status
func AssertStatusReply(t *testing.T, actual redis.Reply, expected string) { func AssertStatusReply(t *testing.T, actual redis.Reply, expected string) {
statusReply, ok := actual.(*reply.StatusReply) statusReply, ok := actual.(*reply.StatusReply)
if !ok { if !ok {
@@ -47,6 +50,7 @@ func AssertStatusReply(t *testing.T, actual redis.Reply, expected string) {
} }
} }
// AssertErrReply checks if the given redis.Reply is the expected error
func AssertErrReply(t *testing.T, actual redis.Reply, expected string) { func AssertErrReply(t *testing.T, actual redis.Reply, expected string) {
errReply, ok := actual.(reply.ErrorReply) errReply, ok := actual.(reply.ErrorReply)
if !ok { if !ok {
@@ -62,6 +66,7 @@ func AssertErrReply(t *testing.T, actual redis.Reply, expected string) {
} }
} }
// AssertNotError checks if the given redis.Reply is not error reply
func AssertNotError(t *testing.T, result redis.Reply) { func AssertNotError(t *testing.T, result redis.Reply) {
if result == nil { if result == nil {
t.Errorf("result is nil %s", printStack()) t.Errorf("result is nil %s", printStack())
@@ -77,6 +82,7 @@ func AssertNotError(t *testing.T, result redis.Reply) {
} }
} }
// AssertNullBulk checks if the given redis.Reply is reply.NullBulkReply
func AssertNullBulk(t *testing.T, result redis.Reply) { func AssertNullBulk(t *testing.T, result redis.Reply) {
if result == nil { if result == nil {
t.Errorf("result is nil %s", printStack()) t.Errorf("result is nil %s", printStack())
@@ -93,6 +99,7 @@ func AssertNullBulk(t *testing.T, result redis.Reply) {
} }
} }
// AssertMultiBulkReply checks if the given redis.Reply has the expected content
func AssertMultiBulkReply(t *testing.T, actual redis.Reply, expected []string) { func AssertMultiBulkReply(t *testing.T, actual redis.Reply, expected []string) {
multiBulk, ok := actual.(*reply.MultiBulkReply) multiBulk, ok := actual.(*reply.MultiBulkReply)
if !ok { if !ok {
@@ -112,6 +119,7 @@ func AssertMultiBulkReply(t *testing.T, actual redis.Reply, expected []string) {
} }
} }
// AssertMultiBulkReplySize check if redis.Reply has expected length
func AssertMultiBulkReplySize(t *testing.T, actual redis.Reply, expected int) { func AssertMultiBulkReplySize(t *testing.T, actual redis.Reply, expected int) {
multiBulk, ok := actual.(*reply.MultiBulkReply) multiBulk, ok := actual.(*reply.MultiBulkReply)
if !ok { if !ok {

View File

@@ -1,50 +1,61 @@
package reply package reply
// PongReply is +PONG
type PongReply struct{} type PongReply struct{}
var PongBytes = []byte("+PONG\r\n") var pongBytes = []byte("+PONG\r\n")
// ToBytes marshal redis.Reply
func (r *PongReply) ToBytes() []byte { func (r *PongReply) ToBytes() []byte {
return PongBytes return pongBytes
} }
// OkReply is +OK
type OkReply struct{} type OkReply struct{}
var okBytes = []byte("+OK\r\n") var okBytes = []byte("+OK\r\n")
// ToBytes marshal redis.Reply
func (r *OkReply) ToBytes() []byte { func (r *OkReply) ToBytes() []byte {
return okBytes return okBytes
} }
var nullBulkBytes = []byte("$-1\r\n") var nullBulkBytes = []byte("$-1\r\n")
// NullBulkReply is empty string
type NullBulkReply struct{} type NullBulkReply struct{}
// ToBytes marshal redis.Reply
func (r *NullBulkReply) ToBytes() []byte { func (r *NullBulkReply) ToBytes() []byte {
return nullBulkBytes return nullBulkBytes
} }
// MakeNullBulkReply creates a new NullBulkReply
func MakeNullBulkReply() *NullBulkReply { func MakeNullBulkReply() *NullBulkReply {
return &NullBulkReply{} return &NullBulkReply{}
} }
var emptyMultiBulkBytes = []byte("*0\r\n") var emptyMultiBulkBytes = []byte("*0\r\n")
// EmptyMultiBulkReply is a empty list
type EmptyMultiBulkReply struct{} type EmptyMultiBulkReply struct{}
// ToBytes marshal redis.Reply
func (r *EmptyMultiBulkReply) ToBytes() []byte { func (r *EmptyMultiBulkReply) ToBytes() []byte {
return emptyMultiBulkBytes return emptyMultiBulkBytes
} }
// MakeEmptyMultiBulkReply creates EmptyMultiBulkReply
func MakeEmptyMultiBulkReply() *EmptyMultiBulkReply { func MakeEmptyMultiBulkReply() *EmptyMultiBulkReply {
return &EmptyMultiBulkReply{} return &EmptyMultiBulkReply{}
} }
// reply nothing, for commands like subscribe // NoReply respond nothing, for commands like subscribe
type NoReply struct{} type NoReply struct{}
var NoBytes = []byte("") var noBytes = []byte("")
// ToBytes marshal redis.Reply
func (r *NoReply) ToBytes() []byte { func (r *NoReply) ToBytes() []byte {
return NoBytes return noBytes
} }

View File

@@ -1,10 +1,11 @@
package reply package reply
// UnknownErr // UnknownErrReply represents UnknownErr
type UnknownErrReply struct{} type UnknownErrReply struct{}
var unknownErrBytes = []byte("-Err unknown\r\n") var unknownErrBytes = []byte("-Err unknown\r\n")
// ToBytes marshals redis.Reply
func (r *UnknownErrReply) ToBytes() []byte { func (r *UnknownErrReply) ToBytes() []byte {
return unknownErrBytes return unknownErrBytes
} }
@@ -13,11 +14,12 @@ func (r *UnknownErrReply) Error() string {
return "Err unknown" return "Err unknown"
} }
// ArgNumErr // ArgNumErrReply represents wrong number of arguments for command
type ArgNumErrReply struct { type ArgNumErrReply struct {
Cmd string Cmd string
} }
// ToBytes marshals redis.Reply
func (r *ArgNumErrReply) ToBytes() []byte { func (r *ArgNumErrReply) ToBytes() []byte {
return []byte("-ERR wrong number of arguments for '" + r.Cmd + "' command\r\n") return []byte("-ERR wrong number of arguments for '" + r.Cmd + "' command\r\n")
} }
@@ -26,11 +28,12 @@ func (r *ArgNumErrReply) Error() string {
return "ERR wrong number of arguments for '" + r.Cmd + "' command" return "ERR wrong number of arguments for '" + r.Cmd + "' command"
} }
// SyntaxErr // SyntaxErrReply represents meeting unexpected arguments
type SyntaxErrReply struct{} type SyntaxErrReply struct{}
var syntaxErrBytes = []byte("-Err syntax error\r\n") var syntaxErrBytes = []byte("-Err syntax error\r\n")
// ToBytes marshals redis.Reply
func (r *SyntaxErrReply) ToBytes() []byte { func (r *SyntaxErrReply) ToBytes() []byte {
return syntaxErrBytes return syntaxErrBytes
} }
@@ -39,11 +42,12 @@ func (r *SyntaxErrReply) Error() string {
return "Err syntax error" return "Err syntax error"
} }
// WrongTypeErr // WrongTypeErrReply represents operation against a key holding the wrong kind of value
type WrongTypeErrReply struct{} type WrongTypeErrReply struct{}
var wrongTypeErrBytes = []byte("-WRONGTYPE Operation against a key holding the wrong kind of value\r\n") var wrongTypeErrBytes = []byte("-WRONGTYPE Operation against a key holding the wrong kind of value\r\n")
// ToBytes marshals redis.Reply
func (r *WrongTypeErrReply) ToBytes() []byte { func (r *WrongTypeErrReply) ToBytes() []byte {
return wrongTypeErrBytes return wrongTypeErrBytes
} }
@@ -54,10 +58,12 @@ func (r *WrongTypeErrReply) Error() string {
// ProtocolErr // ProtocolErr
// ProtocolErrReply represents meeting unexpected byte during parse requests
type ProtocolErrReply struct { type ProtocolErrReply struct {
Msg string Msg string
} }
// ToBytes marshals redis.Reply
func (r *ProtocolErrReply) ToBytes() []byte { func (r *ProtocolErrReply) ToBytes() []byte {
return []byte("-ERR Protocol error: '" + r.Msg + "'\r\n") return []byte("-ERR Protocol error: '" + r.Msg + "'\r\n")
} }

View File

@@ -8,21 +8,26 @@ import (
var ( var (
nullBulkReplyBytes = []byte("$-1") nullBulkReplyBytes = []byte("$-1")
CRLF = "\r\n"
// CRLF is the line separator of redis serialization protocol
CRLF = "\r\n"
) )
/* ---- Bulk Reply ---- */ /* ---- Bulk Reply ---- */
// BulkReply stores a binary-safe string
type BulkReply struct { type BulkReply struct {
Arg []byte Arg []byte
} }
// MakeBulkReply creates BulkReply
func MakeBulkReply(arg []byte) *BulkReply { func MakeBulkReply(arg []byte) *BulkReply {
return &BulkReply{ return &BulkReply{
Arg: arg, Arg: arg,
} }
} }
// ToBytes marshal redis.Reply
func (r *BulkReply) ToBytes() []byte { func (r *BulkReply) ToBytes() []byte {
if len(r.Arg) == 0 { if len(r.Arg) == 0 {
return nullBulkReplyBytes return nullBulkReplyBytes
@@ -32,16 +37,19 @@ func (r *BulkReply) ToBytes() []byte {
/* ---- Multi Bulk Reply ---- */ /* ---- Multi Bulk Reply ---- */
// MultiBulkReply stores a list of string
type MultiBulkReply struct { type MultiBulkReply struct {
Args [][]byte Args [][]byte
} }
// MakeMultiBulkReply creates MultiBulkReply
func MakeMultiBulkReply(args [][]byte) *MultiBulkReply { func MakeMultiBulkReply(args [][]byte) *MultiBulkReply {
return &MultiBulkReply{ return &MultiBulkReply{
Args: args, Args: args,
} }
} }
// ToBytes marshal redis.Reply
func (r *MultiBulkReply) ToBytes() []byte { func (r *MultiBulkReply) ToBytes() []byte {
argLen := len(r.Args) argLen := len(r.Args)
var buf bytes.Buffer var buf bytes.Buffer
@@ -58,16 +66,19 @@ func (r *MultiBulkReply) ToBytes() []byte {
/* ---- Multi Raw Reply ---- */ /* ---- Multi Raw Reply ---- */
// MultiRawReply store complex list structure, for example GeoPos command
type MultiRawReply struct { type MultiRawReply struct {
Args [][]byte Args [][]byte
} }
// MakeMultiRawReply creates MultiRawReply
func MakeMultiRawReply(args [][]byte) *MultiRawReply { func MakeMultiRawReply(args [][]byte) *MultiRawReply {
return &MultiRawReply{ return &MultiRawReply{
Args: args, Args: args,
} }
} }
// ToBytes marshal redis.Reply
func (r *MultiRawReply) ToBytes() []byte { func (r *MultiRawReply) ToBytes() []byte {
argLen := len(r.Args) argLen := len(r.Args)
var buf bytes.Buffer var buf bytes.Buffer
@@ -80,57 +91,68 @@ func (r *MultiRawReply) ToBytes() []byte {
/* ---- Status Reply ---- */ /* ---- Status Reply ---- */
// StatusReply stores a simple status string
type StatusReply struct { type StatusReply struct {
Status string Status string
} }
// MakeStatusReply creates StatusReply
func MakeStatusReply(status string) *StatusReply { func MakeStatusReply(status string) *StatusReply {
return &StatusReply{ return &StatusReply{
Status: status, Status: status,
} }
} }
// ToBytes marshal redis.Reply
func (r *StatusReply) ToBytes() []byte { func (r *StatusReply) ToBytes() []byte {
return []byte("+" + r.Status + "\r\n") return []byte("+" + r.Status + "\r\n")
} }
/* ---- Int Reply ---- */ /* ---- Int Reply ---- */
// IntReply stores an int64 number
type IntReply struct { type IntReply struct {
Code int64 Code int64
} }
// MakeIntReply creates int reply
func MakeIntReply(code int64) *IntReply { func MakeIntReply(code int64) *IntReply {
return &IntReply{ return &IntReply{
Code: code, Code: code,
} }
} }
// ToBytes marshal redis.Reply
func (r *IntReply) ToBytes() []byte { func (r *IntReply) ToBytes() []byte {
return []byte(":" + strconv.FormatInt(r.Code, 10) + CRLF) return []byte(":" + strconv.FormatInt(r.Code, 10) + CRLF)
} }
/* ---- Error Reply ---- */ /* ---- Error Reply ---- */
// ErrorReply is an error and redis.Reply
type ErrorReply interface { type ErrorReply interface {
Error() string Error() string
ToBytes() []byte ToBytes() []byte
} }
// StandardErrReply represents server error
type StandardErrReply struct { type StandardErrReply struct {
Status string Status string
} }
// MakeErrReply creates StandardErrReply
func MakeErrReply(status string) *StandardErrReply { func MakeErrReply(status string) *StandardErrReply {
return &StandardErrReply{ return &StandardErrReply{
Status: status, Status: status,
} }
} }
// IsErrorReply returns true if the given reply is error
func IsErrorReply(reply redis.Reply) bool { func IsErrorReply(reply redis.Reply) bool {
return reply.ToBytes()[0] == '-' return reply.ToBytes()[0] == '-'
} }
// ToBytes marshal redis.Reply
func (r *StandardErrReply) ToBytes() []byte { func (r *StandardErrReply) ToBytes() []byte {
return []byte("-" + r.Status + "\r\n") return []byte("-" + r.Status + "\r\n")
} }

View File

@@ -22,14 +22,14 @@ import (
) )
var ( var (
UnknownErrReplyBytes = []byte("-ERR unknown\r\n") unknownErrReplyBytes = []byte("-ERR unknown\r\n")
) )
// Handler implements tcp.Handler and serves as a redis server // Handler implements tcp.Handler and serves as a redis server
type Handler struct { type Handler struct {
activeConn sync.Map // *client -> placeholder activeConn sync.Map // *client -> placeholder
db db.DB db db.DB
closing atomic.AtomicBool // refusing new client and new request closing atomic.Boolean // refusing new client and new request
} }
// MakeHandler creates a Handler instance // MakeHandler creates a Handler instance
@@ -72,17 +72,16 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
h.closeClient(client) h.closeClient(client)
logger.Info("connection closed: " + client.RemoteAddr().String()) logger.Info("connection closed: " + client.RemoteAddr().String())
return return
} else {
// protocol err
errReply := reply.MakeErrReply(payload.Err.Error())
err := client.Write(errReply.ToBytes())
if err != nil {
h.closeClient(client)
logger.Info("connection closed: " + client.RemoteAddr().String())
return
}
continue
} }
// protocol err
errReply := reply.MakeErrReply(payload.Err.Error())
err := client.Write(errReply.ToBytes())
if err != nil {
h.closeClient(client)
logger.Info("connection closed: " + client.RemoteAddr().String())
return
}
continue
} }
if payload.Data == nil { if payload.Data == nil {
logger.Error("empty payload") logger.Error("empty payload")
@@ -97,7 +96,7 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
if result != nil { if result != nil {
_ = client.Write(result.ToBytes()) _ = client.Write(result.ToBytes())
} else { } else {
_ = client.Write(UnknownErrReplyBytes) _ = client.Write(unknownErrReplyBytes)
} }
} }
} }

View File

@@ -38,4 +38,4 @@ func isAuthenticated(c redis.Connection) bool {
return true return true
} }
return c.GetPassword() == config.Properties.RequirePass return c.GetPassword() == config.Properties.RequirePass
} }

View File

@@ -35,4 +35,4 @@ func TestAuth(t *testing.T) {
ret = Auth(testDB, c, utils.ToBytesList(passwd)) ret = Auth(testDB, c, utils.ToBytesList(passwd))
asserts.AssertStatusReply(t, ret, "OK") asserts.AssertStatusReply(t, ret, "OK")
} }

18
set.go
View File

@@ -190,7 +190,7 @@ func SInter(db *DB, args [][]byte) redis.Reply {
if result == nil { if result == nil {
// init // init
result = HashSet.MakeFromVals(set.ToSlice()...) result = HashSet.Make(set.ToSlice()...)
} else { } else {
result = result.Intersect(set) result = result.Intersect(set)
if result.Len() == 0 { if result.Len() == 0 {
@@ -241,7 +241,7 @@ func SInterStore(db *DB, args [][]byte) redis.Reply {
if result == nil { if result == nil {
// init // init
result = HashSet.MakeFromVals(set.ToSlice()...) result = HashSet.Make(set.ToSlice()...)
} else { } else {
result = result.Intersect(set) result = result.Intersect(set)
if result.Len() == 0 { if result.Len() == 0 {
@@ -252,7 +252,7 @@ func SInterStore(db *DB, args [][]byte) redis.Reply {
} }
} }
set := HashSet.MakeFromVals(result.ToSlice()...) set := HashSet.Make(result.ToSlice()...)
db.PutEntity(dest, &DataEntity{ db.PutEntity(dest, &DataEntity{
Data: set, Data: set,
}) })
@@ -286,7 +286,7 @@ func SUnion(db *DB, args [][]byte) redis.Reply {
if result == nil { if result == nil {
// init // init
result = HashSet.MakeFromVals(set.ToSlice()...) result = HashSet.Make(set.ToSlice()...)
} else { } else {
result = result.Union(set) result = result.Union(set)
} }
@@ -335,7 +335,7 @@ func SUnionStore(db *DB, args [][]byte) redis.Reply {
} }
if result == nil { if result == nil {
// init // init
result = HashSet.MakeFromVals(set.ToSlice()...) result = HashSet.Make(set.ToSlice()...)
} else { } else {
result = result.Union(set) result = result.Union(set)
} }
@@ -347,7 +347,7 @@ func SUnionStore(db *DB, args [][]byte) redis.Reply {
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
set := HashSet.MakeFromVals(result.ToSlice()...) set := HashSet.Make(result.ToSlice()...)
db.PutEntity(dest, &DataEntity{ db.PutEntity(dest, &DataEntity{
Data: set, Data: set,
}) })
@@ -385,7 +385,7 @@ func SDiff(db *DB, args [][]byte) redis.Reply {
} }
if result == nil { if result == nil {
// init // init
result = HashSet.MakeFromVals(set.ToSlice()...) result = HashSet.Make(set.ToSlice()...)
} else { } else {
result = result.Diff(set) result = result.Diff(set)
if result.Len() == 0 { if result.Len() == 0 {
@@ -443,7 +443,7 @@ func SDiffStore(db *DB, args [][]byte) redis.Reply {
} }
if result == nil { if result == nil {
// init // init
result = HashSet.MakeFromVals(set.ToSlice()...) result = HashSet.Make(set.ToSlice()...)
} else { } else {
result = result.Diff(set) result = result.Diff(set)
if result.Len() == 0 { if result.Len() == 0 {
@@ -459,7 +459,7 @@ func SDiffStore(db *DB, args [][]byte) redis.Reply {
db.Remove(dest) db.Remove(dest)
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
set := HashSet.MakeFromVals(result.ToSlice()...) set := HashSet.Make(result.ToSlice()...)
db.PutEntity(dest, &DataEntity{ db.PutEntity(dest, &DataEntity{
Data: set, Data: set,
}) })

View File

@@ -156,6 +156,7 @@ func TestZRangeByScore(t *testing.T) {
setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i]) setArgs = append(setArgs, strconv.FormatInt(int64(scores[i]), 10), members[i])
} }
result := ZAdd(testDB, utils.ToBytesList(setArgs...)) result := ZAdd(testDB, utils.ToBytesList(setArgs...))
asserts.AssertIntReply(t, result, size)
min := "20" min := "20"
max := "30" max := "30"

View File

@@ -16,33 +16,38 @@ import (
"time" "time"
) )
// EchoHandler echos received line to client, using for test
type EchoHandler struct { type EchoHandler struct {
activeConn sync.Map activeConn sync.Map
closing atomic.AtomicBool closing atomic.Boolean
} }
// MakeEchoHandler creates EchoHandler
func MakeEchoHandler() *EchoHandler { func MakeEchoHandler() *EchoHandler {
return &EchoHandler{} return &EchoHandler{}
} }
type Client struct { // EchoClient is client for EchoHandler, using for test
type EchoClient struct {
Conn net.Conn Conn net.Conn
Waiting wait.Wait Waiting wait.Wait
} }
func (c *Client) Close() error { // Close close connection
func (c *EchoClient) Close() error {
c.Waiting.WaitWithTimeout(10 * time.Second) c.Waiting.WaitWithTimeout(10 * time.Second)
c.Conn.Close() c.Conn.Close()
return nil return nil
} }
// Handle echos received line to client
func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) { func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
if h.closing.Get() { if h.closing.Get() {
// closing handler refuse new connection // closing handler refuse new connection
conn.Close() conn.Close()
} }
client := &Client{ client := &EchoClient{
Conn: conn, Conn: conn,
} }
h.activeConn.Store(client, 1) h.activeConn.Store(client, 1)
@@ -69,12 +74,13 @@ func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
} }
} }
// Close stops echo hanlder
func (h *EchoHandler) Close() error { func (h *EchoHandler) Close() error {
logger.Info("handler shuting down...") logger.Info("handler shuting down...")
h.closing.Set(true) h.closing.Set(true)
// TODO: concurrent wait // TODO: concurrent wait
h.activeConn.Range(func(key interface{}, val interface{}) bool { h.activeConn.Range(func(key interface{}, val interface{}) bool {
client := key.(*Client) client := key.(*EchoClient)
client.Close() client.Close()
return true return true
}) })

View File

@@ -18,12 +18,14 @@ import (
"time" "time"
) )
// Config stores tcp server properties
type Config struct { type Config struct {
Address string `yaml:"address"` Address string `yaml:"address"`
MaxConnect uint32 `yaml:"max-connect"` MaxConnect uint32 `yaml:"max-connect"`
Timeout time.Duration `yaml:"timeout"` Timeout time.Duration `yaml:"timeout"`
} }
// ListenAndServeWithSignal binds port and handle requests, blocking until receive stop signal
func ListenAndServeWithSignal(cfg *Config, handler tcp.Handler) error { func ListenAndServeWithSignal(cfg *Config, handler tcp.Handler) error {
closeChan := make(chan struct{}) closeChan := make(chan struct{})
sigCh := make(chan os.Signal) sigCh := make(chan os.Signal)
@@ -45,9 +47,10 @@ func ListenAndServeWithSignal(cfg *Config, handler tcp.Handler) error {
return nil return nil
} }
// ListenAndServe binds port and handle requests, blocking until close
func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan struct{}) { func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan struct{}) {
// listen signal // listen signal
var closing atomic.AtomicBool var closing atomic.Boolean
go func() { go func() {
<-closeChan <-closeChan
logger.Info("shutting down...") logger.Info("shutting down...")

View File

@@ -12,4 +12,3 @@ func makeTestDB() *DB {
locker: lock.Make(lockerSize), locker: lock.Make(lockerSize),
} }
} }