reformat code

This commit is contained in:
hdt3213
2021-04-03 20:14:12 +08:00
parent bf913a5aca
commit bcf0cd5e92
54 changed files with 4887 additions and 4896 deletions

View File

@@ -2,11 +2,13 @@
[中文版](https://github.com/HDT3213/godis/blob/master/README_CN.md) [中文版](https://github.com/HDT3213/godis/blob/master/README_CN.md)
`Godis` is a simple implementation of Redis Server, which intents to provide an example of writing a high concurrent middleware using golang. `Godis` is a simple implementation of Redis Server, which intents to provide an example of writing a high concurrent
middleware using golang.
Please be advised, NEVER think about using this in production environment. Please be advised, NEVER think about using this in production environment.
This repository implemented most features of redis, including 5 data structures, ttl, publish/subscribe, AOF persistence and server side cluster mode. This repository implemented most features of redis, including 5 data structures, ttl, publish/subscribe, AOF persistence
and server side cluster mode.
If you could read Chinese, you can find more details in [My Blog](https://www.cnblogs.com/Finley/category/1598973.html). If you could read Chinese, you can find more details in [My Blog](https://www.cnblogs.com/Finley/category/1598973.html).
@@ -35,8 +37,7 @@ peers localhost:7379,localhost:7389 // other node in cluster
self localhost:6399 // self address self localhost:6399 // self address
``` ```
We provide node1.conf and node2.conf for demonstration. We provide node1.conf and node2.conf for demonstration. use following command line to start a two-node-cluster:
use following command line to start a two-node-cluster:
```bash ```bash
CONFIG=node1.conf ./godis-darwin & CONFIG=node1.conf ./godis-darwin &

View File

@@ -2,8 +2,8 @@ Godis 是一个用 Go 语言实现的 Redis 服务器。本项目旨在为尝试
**请注意:不要在生产环境使用使用此项目** **请注意:不要在生产环境使用使用此项目**
Godis 实现了 Redis 的大多数功能包括5种数据结构、TTL、发布订阅以及 AOF 持久化。可以在[我的博客](https://www.cnblogs.com/Finley/category/1598973.html)了解更多关于 Godis 的信息。 Godis 实现了 Redis 的大多数功能包括5种数据结构、TTL、发布订阅以及 AOF 持久化。可以在[我的博客](https://www.cnblogs.com/Finley/category/1598973.html)了解更多关于
Godis 的信息。
# 运行 Godis # 运行 Godis

View File

@@ -37,4 +37,3 @@ func RenameNx(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
} }
return cluster.Relay(srcPeer, c, args) return cluster.Relay(srcPeer, c, args)
} }

View File

@@ -1,6 +1,6 @@
package dict package dict
type Consumer func(key string, val interface{})bool type Consumer func(key string, val interface{}) bool
type Dict interface { type Dict interface {
Get(key string) (val interface{}, exists bool) Get(key string) (val interface{}, exists bool)

View File

@@ -61,7 +61,7 @@ func TestPutIfAbsent(t *testing.T) {
} }
// update // update
ret = d.PutIfAbsent(key, i * 10) ret = d.PutIfAbsent(key, i*10)
if ret != 0 { // no update if ret != 0 { // no update
t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret)) t.Error("put test failed: expected result 0, actual: " + strconv.Itoa(ret))
} }
@@ -96,12 +96,12 @@ func TestPutIfExists(t *testing.T) {
} }
d.Put(key, i) d.Put(key, i)
ret = d.PutIfExists(key, 10 * i) ret = d.PutIfExists(key, 10*i)
val, ok := d.Get(key) val, ok := d.Get(key)
if ok { if ok {
intVal, _ := val.(int) intVal, _ := val.(int)
if intVal != 10 * i { if intVal != 10*i {
t.Error("put test failed: expected " + strconv.Itoa(10 * i) + ", actual: " + strconv.Itoa(intVal)) t.Error("put test failed: expected " + strconv.Itoa(10*i) + ", actual: " + strconv.Itoa(intVal))
} }
} else { } else {
_, ok := d.Get(key) _, ok := d.Get(key)
@@ -229,7 +229,7 @@ func TestForEach(t *testing.T) {
d.Put(key, i) d.Put(key, i)
} }
i := 0 i := 0
d.ForEach(func(key string, value interface{})bool { d.ForEach(func(key string, value interface{}) bool {
intVal, _ := value.(int) intVal, _ := value.(int)
expectedKey := "k" + strconv.Itoa(intVal) expectedKey := "k" + strconv.Itoa(intVal)
if key != expectedKey { if key != expectedKey {

View File

@@ -11,10 +11,10 @@ type LinkedList struct {
type node struct { type node struct {
val interface{} val interface{}
prev *node prev *node
next * node next *node
} }
func (list *LinkedList)Add(val interface{}) { func (list *LinkedList) Add(val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
@@ -33,8 +33,8 @@ func (list *LinkedList)Add(val interface{}) {
list.size++ list.size++
} }
func (list *LinkedList)find(index int)(n *node) { func (list *LinkedList) find(index int) (n *node) {
if index < list.size / 2 { if index < list.size/2 {
n := list.first n := list.first
for i := 0; i < index; i++ { for i := 0; i < index; i++ {
n = n.next n = n.next
@@ -49,7 +49,7 @@ func (list *LinkedList)find(index int)(n *node) {
} }
} }
func (list *LinkedList)Get(index int)(val interface{}) { func (list *LinkedList) Get(index int) (val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
@@ -59,7 +59,7 @@ func (list *LinkedList)Get(index int)(val interface{}) {
return list.find(index).val return list.find(index).val
} }
func (list *LinkedList)Set(index int, val interface{}) { func (list *LinkedList) Set(index int, val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
@@ -70,7 +70,7 @@ func (list *LinkedList)Set(index int, val interface{}) {
n.val = val n.val = val
} }
func (list *LinkedList)Insert(index int, val interface{}) { func (list *LinkedList) Insert(index int, val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
@@ -99,7 +99,7 @@ func (list *LinkedList)Insert(index int, val interface{}) {
} }
} }
func (list *LinkedList)removeNode(n *node) { func (list *LinkedList) removeNode(n *node) {
if n.prev == nil { if n.prev == nil {
list.first = n.next list.first = n.next
} else { } else {
@@ -118,7 +118,7 @@ func (list *LinkedList)removeNode(n *node) {
list.size-- list.size--
} }
func (list *LinkedList)Remove(index int)(val interface{}) { func (list *LinkedList) Remove(index int) (val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
@@ -131,7 +131,7 @@ func (list *LinkedList)Remove(index int)(val interface{}) {
return n.val return n.val
} }
func (list *LinkedList)RemoveLast()(val interface{}) { func (list *LinkedList) RemoveLast() (val interface{}) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
@@ -144,7 +144,7 @@ func (list *LinkedList)RemoveLast()(val interface{}) {
return n.val return n.val
} }
func (list *LinkedList)RemoveAllByVal(val interface{})int { func (list *LinkedList) RemoveAllByVal(val interface{}) int {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
@@ -176,7 +176,7 @@ func (list *LinkedList)RemoveAllByVal(val interface{})int {
* remove at most `count` values of the specified value in this list * remove at most `count` values of the specified value in this list
* scan from left to right * scan from left to right
*/ */
func (list *LinkedList) RemoveByVal(val interface{}, count int)int { func (list *LinkedList) RemoveByVal(val interface{}, count int) int {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
@@ -208,7 +208,7 @@ func (list *LinkedList) RemoveByVal(val interface{}, count int)int {
return removed return removed
} }
func (list *LinkedList) ReverseRemoveByVal(val interface{}, count int)int { func (list *LinkedList) ReverseRemoveByVal(val interface{}, count int) int {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
@@ -240,14 +240,14 @@ func (list *LinkedList) ReverseRemoveByVal(val interface{}, count int)int {
return removed return removed
} }
func (list *LinkedList)Len()int { func (list *LinkedList) Len() int {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
return list.size return list.size
} }
func (list *LinkedList)ForEach(consumer func(int, interface{})bool) { func (list *LinkedList) ForEach(consumer func(int, interface{}) bool) {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }
@@ -264,7 +264,7 @@ func (list *LinkedList)ForEach(consumer func(int, interface{})bool) {
} }
} }
func (list *LinkedList)Contains(val interface{})bool { func (list *LinkedList) Contains(val interface{}) bool {
contains := false contains := false
list.ForEach(func(i int, actual interface{}) bool { list.ForEach(func(i int, actual interface{}) bool {
if actual == val { if actual == val {
@@ -276,7 +276,7 @@ func (list *LinkedList)Contains(val interface{})bool {
return contains return contains
} }
func (list *LinkedList)Range(start int, stop int)[]interface{} { func (list *LinkedList) Range(start int, stop int) []interface{} {
if list == nil { if list == nil {
panic("list is nil") panic("list is nil")
} }

View File

@@ -1,9 +1,9 @@
package list package list
import ( import (
"testing"
"strconv" "strconv"
"strings" "strings"
"testing"
) )
func ToString(list *LinkedList) string { func ToString(list *LinkedList) string {
@@ -137,7 +137,7 @@ func TestInsert(t *testing.T) {
list.ForEach(func(j int, v interface{}) bool { list.ForEach(func(j int, v interface{}) bool {
var expected int var expected int
if j < (i + 1) * 2 { if j < (i+1)*2 {
if j%2 == 0 { if j%2 == 0 {
expected = j / 2 expected = j / 2
} else { } else {
@@ -155,7 +155,7 @@ func TestInsert(t *testing.T) {
for j := 0; j < list.Len(); j++ { for j := 0; j < list.Len(); j++ {
var expected int var expected int
if j < (i + 1) * 2 { if j < (i+1)*2 {
if j%2 == 0 { if j%2 == 0 {
expected = j / 2 expected = j / 2
} else { } else {
@@ -196,8 +196,8 @@ func TestRange(t *testing.T) {
for start := 0; start < size; start++ { for start := 0; start < size; start++ {
for stop := start; stop < size; stop++ { for stop := start; stop < size; stop++ {
slice := list.Range(start, stop) slice := list.Range(start, stop)
if len(slice) != stop - start { if len(slice) != stop-start {
t.Error("expected " + strconv.Itoa(stop - start) + ", get: " + strconv.Itoa(len(slice)) + t.Error("expected " + strconv.Itoa(stop-start) + ", get: " + strconv.Itoa(len(slice)) +
", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]") ", range: [" + strconv.Itoa(start) + "," + strconv.Itoa(stop) + "]")
} }
sliceIndex := 0 sliceIndex := 0

View File

@@ -46,25 +46,25 @@ func (locks *Locks) spread(hashCode uint32) uint32 {
return (tableSize - 1) & uint32(hashCode) return (tableSize - 1) & uint32(hashCode)
} }
func (locks *Locks)Lock(key string) { func (locks *Locks) Lock(key string) {
index := locks.spread(fnv32(key)) index := locks.spread(fnv32(key))
mu := locks.table[index] mu := locks.table[index]
mu.Lock() mu.Lock()
} }
func (locks *Locks)RLock(key string) { func (locks *Locks) RLock(key string) {
index := locks.spread(fnv32(key)) index := locks.spread(fnv32(key))
mu := locks.table[index] mu := locks.table[index]
mu.RLock() mu.RLock()
} }
func (locks *Locks)UnLock(key string) { func (locks *Locks) UnLock(key string) {
index := locks.spread(fnv32(key)) index := locks.spread(fnv32(key))
mu := locks.table[index] mu := locks.table[index]
mu.Unlock() mu.Unlock()
} }
func (locks *Locks)RUnLock(key string) { func (locks *Locks) RUnLock(key string) {
index := locks.spread(fnv32(key)) index := locks.spread(fnv32(key))
mu := locks.table[index] mu := locks.table[index]
mu.RUnlock() mu.RUnlock()
@@ -90,7 +90,7 @@ func (locks *Locks) toLockIndices(keys []string, reverse bool) []uint32 {
return indices return indices
} }
func (locks *Locks)Locks(keys ...string) { func (locks *Locks) Locks(keys ...string) {
indices := locks.toLockIndices(keys, false) indices := locks.toLockIndices(keys, false)
for _, index := range indices { for _, index := range indices {
mu := locks.table[index] mu := locks.table[index]
@@ -98,7 +98,7 @@ func (locks *Locks)Locks(keys ...string) {
} }
} }
func (locks *Locks)RLocks(keys ...string) { func (locks *Locks) RLocks(keys ...string) {
indices := locks.toLockIndices(keys, false) indices := locks.toLockIndices(keys, false)
for _, index := range indices { for _, index := range indices {
mu := locks.table[index] mu := locks.table[index]
@@ -106,8 +106,7 @@ func (locks *Locks)RLocks(keys ...string) {
} }
} }
func (locks *Locks) UnLocks(keys ...string) {
func (locks *Locks)UnLocks(keys ...string) {
indices := locks.toLockIndices(keys, true) indices := locks.toLockIndices(keys, true)
for _, index := range indices { for _, index := range indices {
mu := locks.table[index] mu := locks.table[index]
@@ -115,7 +114,7 @@ func (locks *Locks)UnLocks(keys ...string) {
} }
} }
func (locks *Locks)RUnLocks(keys ...string) { func (locks *Locks) RUnLocks(keys ...string) {
indices := locks.toLockIndices(keys, true) indices := locks.toLockIndices(keys, true)
for _, index := range indices { for _, index := range indices {
mu := locks.table[index] mu := locks.table[index]

View File

@@ -26,7 +26,7 @@ type ScoreBorder struct {
// if max.greater(score) then the score is within the upper border // if max.greater(score) then the score is within the upper border
// do not use min.greater() // do not use min.greater()
func (border *ScoreBorder)greater(value float64)bool { func (border *ScoreBorder) greater(value float64) bool {
if border.Inf == negativeInf { if border.Inf == negativeInf {
return false return false
} else if border.Inf == positiveInf { } else if border.Inf == positiveInf {
@@ -39,7 +39,7 @@ func (border *ScoreBorder)greater(value float64)bool {
} }
} }
func (border *ScoreBorder)less(value float64)bool { func (border *ScoreBorder) less(value float64) bool {
if border.Inf == negativeInf { if border.Inf == negativeInf {
return true return true
} else if border.Inf == positiveInf { } else if border.Inf == positiveInf {
@@ -52,15 +52,15 @@ func (border *ScoreBorder)less(value float64)bool {
} }
} }
var positiveInfBorder = &ScoreBorder { var positiveInfBorder = &ScoreBorder{
Inf: positiveInf, Inf: positiveInf,
} }
var negativeInfBorder = &ScoreBorder { var negativeInfBorder = &ScoreBorder{
Inf: negativeInf, Inf: negativeInf,
} }
func ParseScoreBorder(s string)(*ScoreBorder, error) { func ParseScoreBorder(s string) (*ScoreBorder, error) {
if s == "inf" || s == "+inf" { if s == "inf" || s == "+inf" {
return positiveInfBorder, nil return positiveInfBorder, nil
} }

View File

@@ -6,7 +6,6 @@ const (
maxLevel = 16 maxLevel = 16
) )
type Element struct { type Element struct {
Member string Member string
Score float64 Score float64
@@ -31,7 +30,7 @@ type skiplist struct {
level int16 level int16
} }
func makeNode(level int16, score float64, member string)*Node { func makeNode(level int16, score float64, member string) *Node {
n := &Node{ n := &Node{
Element: Element{ Element: Element{
Score: score, Score: score,
@@ -45,7 +44,7 @@ func makeNode(level int16, score float64, member string)*Node {
return n return n
} }
func makeSkiplist()*skiplist { func makeSkiplist() *skiplist {
return &skiplist{ return &skiplist{
level: 1, level: 1,
header: makeNode(maxLevel, 0, ""), header: makeNode(maxLevel, 0, ""),
@@ -63,17 +62,17 @@ func randomLevel() int16 {
return maxLevel return maxLevel
} }
func (skiplist *skiplist)insert(member string, score float64)*Node { func (skiplist *skiplist) insert(member string, score float64) *Node {
update := make([]*Node, maxLevel) // link new node with node in `update` update := make([]*Node, maxLevel) // link new node with node in `update`
rank := make([]int64, maxLevel) rank := make([]int64, maxLevel)
// find position to insert // find position to insert
node := skiplist.header node := skiplist.header
for i := skiplist.level - 1; i >= 0; i-- { for i := skiplist.level - 1; i >= 0; i-- {
if i == skiplist.level - 1 { if i == skiplist.level-1 {
rank[i] = 0 rank[i] = 0
} else { } else {
rank[i] = rank[i + 1] // store rank that is crossed to reach the insert position rank[i] = rank[i+1] // store rank that is crossed to reach the insert position
} }
if node.level[i] != nil { if node.level[i] != nil {
// traverse the skip list // traverse the skip list
@@ -156,7 +155,7 @@ func (skiplist *skiplist) removeNode(node *Node, update []*Node) {
/* /*
* return: has found and removed node * return: has found and removed node
*/ */
func (skiplist *skiplist) remove(member string, score float64)bool { func (skiplist *skiplist) remove(member string, score float64) bool {
/* /*
* find backward node (of target) or last node of each level * find backward node (of target) or last node of each level
* their forward need to be updated * their forward need to be updated
@@ -184,7 +183,7 @@ func (skiplist *skiplist) remove(member string, score float64)bool {
/* /*
* return: 1 based rank, 0 means member not found * return: 1 based rank, 0 means member not found
*/ */
func (skiplist *skiplist) getRank(member string, score float64)int64 { func (skiplist *skiplist) getRank(member string, score float64) int64 {
var rank int64 = 0 var rank int64 = 0
x := skiplist.header x := skiplist.header
for i := skiplist.level - 1; i >= 0; i-- { for i := skiplist.level - 1; i >= 0; i-- {
@@ -207,7 +206,7 @@ func (skiplist *skiplist) getRank(member string, score float64)int64 {
/* /*
* 1-based rank * 1-based rank
*/ */
func (skiplist *skiplist) getByRank(rank int64)*Node { func (skiplist *skiplist) getByRank(rank int64) *Node {
var i int64 = 0 var i int64 = 0
n := skiplist.header n := skiplist.header
// scan from top level // scan from top level
@@ -281,7 +280,7 @@ func (skiplist *skiplist) getLastInScoreRange(min *ScoreBorder, max *ScoreBorder
/* /*
* return removed elements * return removed elements
*/ */
func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder)(removed []*Element) { func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder) (removed []*Element) {
update := make([]*Node, maxLevel) update := make([]*Node, maxLevel)
removed = make([]*Element, 0) removed = make([]*Element, 0)
// find backward nodes (of target range) or last node of each level // find backward nodes (of target range) or last node of each level
@@ -314,7 +313,7 @@ func (skiplist *skiplist) RemoveRangeByScore(min *ScoreBorder, max *ScoreBorder)
} }
// 1-based rank, including start, exclude stop // 1-based rank, including start, exclude stop
func (skiplist *skiplist) RemoveRangeByRank(start int64, stop int64)(removed []*Element) { func (skiplist *skiplist) RemoveRangeByRank(start int64, stop int64) (removed []*Element) {
var i int64 = 0 // rank of iterator var i int64 = 0 // rank of iterator
update := make([]*Node, maxLevel) update := make([]*Node, maxLevel)
removed = make([]*Element, 0) removed = make([]*Element, 0)

View File

@@ -9,7 +9,7 @@ type SortedSet struct {
skiplist *skiplist skiplist *skiplist
} }
func Make()*SortedSet { func Make() *SortedSet {
return &SortedSet{ return &SortedSet{
dict: make(map[string]*Element), dict: make(map[string]*Element),
skiplist: makeSkiplist(), skiplist: makeSkiplist(),
@@ -19,7 +19,7 @@ func Make()*SortedSet {
/* /*
* return: has inserted new node * return: has inserted new node
*/ */
func (sortedSet *SortedSet)Add(member string, score float64)bool { func (sortedSet *SortedSet) Add(member string, score float64) bool {
element, ok := sortedSet.dict[member] element, ok := sortedSet.dict[member]
sortedSet.dict[member] = &Element{ sortedSet.dict[member] = &Element{
Member: member, Member: member,
@@ -37,7 +37,7 @@ func (sortedSet *SortedSet)Add(member string, score float64)bool {
} }
} }
func (sortedSet *SortedSet) Len()int64 { func (sortedSet *SortedSet) Len() int64 {
return int64(len(sortedSet.dict)) return int64(len(sortedSet.dict))
} }
@@ -49,7 +49,7 @@ func (sortedSet *SortedSet) Get(member string) (element *Element, ok bool) {
return element, true return element, true
} }
func (sortedSet *SortedSet) Remove(member string)bool { func (sortedSet *SortedSet) Remove(member string) bool {
v, ok := sortedSet.dict[member] v, ok := sortedSet.dict[member]
if ok { if ok {
sortedSet.skiplist.remove(member, v.Score) sortedSet.skiplist.remove(member, v.Score)
@@ -79,7 +79,7 @@ func (sortedSet *SortedSet) GetRank(member string, desc bool) (rank int64) {
/** /**
* traverse [start, stop), 0-based rank * traverse [start, stop), 0-based rank
*/ */
func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer func(element *Element)bool) { func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer func(element *Element) bool) {
size := int64(sortedSet.Len()) size := int64(sortedSet.Len())
if start < 0 || start >= size { if start < 0 || start >= size {
panic("illegal start " + strconv.FormatInt(start, 10)) panic("illegal start " + strconv.FormatInt(start, 10))
@@ -119,11 +119,11 @@ func (sortedSet *SortedSet) ForEach(start int64, stop int64, desc bool, consumer
* return [start, stop), 0-based rank * return [start, stop), 0-based rank
* assert start in [0, size), stop in [start, size] * assert start in [0, size), stop in [start, size]
*/ */
func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool)[]*Element { func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool) []*Element {
sliceSize := int(stop - start) sliceSize := int(stop - start)
slice := make([]*Element, sliceSize) slice := make([]*Element, sliceSize)
i := 0 i := 0
sortedSet.ForEach(start, stop, desc, func(element *Element)bool { sortedSet.ForEach(start, stop, desc, func(element *Element) bool {
slice[i] = element slice[i] = element
i++ i++
return true return true
@@ -131,7 +131,7 @@ func (sortedSet *SortedSet) Range(start int64, stop int64, desc bool)[]*Element
return slice return slice
} }
func (sortedSet *SortedSet) Count(min *ScoreBorder, max *ScoreBorder)int64 { func (sortedSet *SortedSet) Count(min *ScoreBorder, max *ScoreBorder) int64 {
var i int64 = 0 var i int64 = 0
// ascending order // ascending order
sortedSet.ForEach(0, sortedSet.Len(), false, func(element *Element) bool { sortedSet.ForEach(0, sortedSet.Len(), false, func(element *Element) bool {
@@ -194,8 +194,8 @@ func (sortedSet *SortedSet) ForEachByScore(min *ScoreBorder, max *ScoreBorder, o
/* /*
* param limit: <0 means no limit * param limit: <0 means no limit
*/ */
func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool)[]*Element { func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, offset int64, limit int64, desc bool) []*Element {
if limit == 0 || offset < 0{ if limit == 0 || offset < 0 {
return make([]*Element, 0) return make([]*Element, 0)
} }
slice := make([]*Element, 0) slice := make([]*Element, 0)
@@ -206,7 +206,7 @@ func (sortedSet *SortedSet) RangeByScore(min *ScoreBorder, max *ScoreBorder, off
return slice return slice
} }
func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder)int64 { func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder) int64 {
removed := sortedSet.skiplist.RemoveRangeByScore(min, max) removed := sortedSet.skiplist.RemoveRangeByScore(min, max)
for _, element := range removed { for _, element := range removed {
delete(sortedSet.dict, element.Member) delete(sortedSet.dict, element.Member)
@@ -214,12 +214,11 @@ func (sortedSet *SortedSet) RemoveByScore(min *ScoreBorder, max *ScoreBorder)int
return int64(len(removed)) return int64(len(removed))
} }
/* /*
* 0-based rank, [start, stop) * 0-based rank, [start, stop)
*/ */
func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64)int64 { func (sortedSet *SortedSet) RemoveByRank(start int64, stop int64) int64 {
removed := sortedSet.skiplist.RemoveRangeByRank(start + 1, stop + 1) removed := sortedSet.skiplist.RemoveRangeByRank(start+1, stop+1)
for _, element := range removed { for _, element := range removed {
delete(sortedSet.dict, element.Member) delete(sortedSet.dict, element.Member)
} }

View File

@@ -1,6 +1,6 @@
package utils package utils
func Equals(a interface{}, b interface{})bool { func Equals(a interface{}, b interface{}) bool {
sliceA, okA := a.([]byte) sliceA, okA := a.([]byte)
sliceB, okB := b.([]byte) sliceB, okB := b.([]byte)
if okA && okB { if okA && okB {

View File

@@ -140,7 +140,7 @@ func HDel(db *DB, args [][]byte) redis.Reply {
return reply.MakeErrReply("ERR wrong number of arguments for 'hdel' command") return reply.MakeErrReply("ERR wrong number of arguments for 'hdel' command")
} }
key := string(args[0]) key := string(args[0])
fields := make([]string, len(args) - 1) fields := make([]string, len(args)-1)
fieldArgs := args[1:] fieldArgs := args[1:]
for i, v := range fieldArgs { for i, v := range fieldArgs {
fields[i] = string(v) fields[i] = string(v)
@@ -192,7 +192,7 @@ func HLen(db *DB, args [][]byte) redis.Reply {
func HMSet(db *DB, args [][]byte) redis.Reply { func HMSet(db *DB, args [][]byte) redis.Reply {
// parse args // parse args
if len(args) < 3 || len(args) % 2 != 1 { if len(args) < 3 || len(args)%2 != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'hmset' command") return reply.MakeErrReply("ERR wrong number of arguments for 'hmset' command")
} }
key := string(args[0]) key := string(args[0])
@@ -200,8 +200,8 @@ func HMSet(db *DB, args [][]byte) redis.Reply {
fields := make([]string, size) fields := make([]string, size)
values := make([][]byte, size) values := make([][]byte, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
fields[i] = string(args[2 * i + 1]) fields[i] = string(args[2*i+1])
values[i] = args[2 * i + 2] values[i] = args[2*i+2]
} }
// lock key // lock key
@@ -231,7 +231,7 @@ func HMGet(db *DB, args [][]byte) redis.Reply {
size := len(args) - 1 size := len(args) - 1
fields := make([]string, size) fields := make([]string, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
fields[i] = string(args[i + 1]) fields[i] = string(args[i+1])
} }
db.RLock(key) db.RLock(key)
@@ -278,7 +278,7 @@ func HKeys(db *DB, args [][]byte) redis.Reply {
fields := make([][]byte, dict.Len()) fields := make([][]byte, dict.Len())
i := 0 i := 0
dict.ForEach(func(key string, val interface{})bool { dict.ForEach(func(key string, val interface{}) bool {
fields[i] = []byte(key) fields[i] = []byte(key)
i++ i++
return true return true
@@ -306,7 +306,7 @@ func HVals(db *DB, args [][]byte) redis.Reply {
values := make([][]byte, dict.Len()) values := make([][]byte, dict.Len())
i := 0 i := 0
dict.ForEach(func(key string, val interface{})bool { dict.ForEach(func(key string, val interface{}) bool {
values[i], _ = val.([]byte) values[i], _ = val.([]byte)
i++ i++
return true return true
@@ -333,9 +333,9 @@ func HGetAll(db *DB, args [][]byte) redis.Reply {
} }
size := dict.Len() size := dict.Len()
result := make([][]byte, size * 2) result := make([][]byte, size*2)
i := 0 i := 0
dict.ForEach(func(key string, val interface{})bool { dict.ForEach(func(key string, val interface{}) bool {
result[i] = []byte(key) result[i] = []byte(key)
i++ i++
result[i], _ = val.([]byte) result[i], _ = val.([]byte)
@@ -414,7 +414,7 @@ func HIncrByFloat(db *DB, args [][]byte) redis.Reply {
return reply.MakeErrReply("ERR hash value is not a float") return reply.MakeErrReply("ERR hash value is not a float")
} }
result := val.Add(delta) result := val.Add(delta)
resultBytes:= []byte(result.String()) resultBytes := []byte(result.String())
dict.Put(field, resultBytes) dict.Put(field, resultBytes)
db.AddAof(makeAofCmd("hincrbyfloat", args)) db.AddAof(makeAofCmd("hincrbyfloat", args))
return reply.MakeBulkReply(resultBytes) return reply.MakeBulkReply(resultBytes)

View File

@@ -7,7 +7,7 @@ import (
"strconv" "strconv"
) )
func (db *DB) getAsList(key string)(*List.LinkedList, reply.ErrorReply) { func (db *DB) getAsList(key string) (*List.LinkedList, reply.ErrorReply) {
entity, ok := db.Get(key) entity, ok := db.Get(key)
if !ok { if !ok {
return nil, nil return nil, nil
@@ -19,7 +19,7 @@ func (db *DB) getAsList(key string)(*List.LinkedList, reply.ErrorReply) {
return bytes, nil return bytes, nil
} }
func (db *DB) getOrInitList(key string)(list *List.LinkedList, inited bool, errReply reply.ErrorReply) { func (db *DB) getOrInitList(key string) (list *List.LinkedList, inited bool, errReply reply.ErrorReply) {
list, errReply = db.getAsList(key) list, errReply = db.getAsList(key)
if errReply != nil { if errReply != nil {
return nil, false, errReply return nil, false, errReply
@@ -57,7 +57,7 @@ func LIndex(db *DB, args [][]byte) redis.Reply {
} }
size := list.Len() // assert: size > 0 size := list.Len() // assert: size > 0
if index < -1 * size { if index < -1*size {
return &reply.NullBulkReply{} return &reply.NullBulkReply{}
} else if index < 0 { } else if index < 0 {
index = size + index index = size + index
@@ -202,14 +202,14 @@ func LRange(db *DB, args [][]byte) redis.Reply {
// compute index // compute index
size := list.Len() // assert: size > 0 size := list.Len() // assert: size > 0
if start < -1 * size { if start < -1*size {
start = 0 start = 0
} else if start < 0 { } else if start < 0 {
start = size + start start = size + start
} else if start >= size { } else if start >= size {
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
if stop < -1 * size { if stop < -1*size {
stop = 0 stop = 0
} else if stop < 0 { } else if stop < 0 {
stop = size + stop + 1 stop = size + stop + 1
@@ -304,7 +304,7 @@ func LSet(db *DB, args [][]byte) redis.Reply {
} }
size := list.Len() // assert: size > 0 size := list.Len() // assert: size > 0
if index < -1 * size { if index < -1*size {
return reply.MakeErrReply("ERR index out of range") return reply.MakeErrReply("ERR index out of range")
} else if index < 0 { } else if index < 0 {
index = size + index index = size + index

View File

@@ -1,6 +1,6 @@
package db package db
func MakeRouter()map[string]CmdFunc { func MakeRouter() map[string]CmdFunc {
routerMap := make(map[string]CmdFunc) routerMap := make(map[string]CmdFunc)
routerMap["ping"] = Ping routerMap["ping"] = Ping

View File

@@ -8,7 +8,7 @@ import (
"strings" "strings"
) )
func (db *DB)getAsSortedSet(key string)(*SortedSet.SortedSet, reply.ErrorReply) { func (db *DB) getAsSortedSet(key string) (*SortedSet.SortedSet, reply.ErrorReply) {
entity, exists := db.Get(key) entity, exists := db.Get(key)
if !exists { if !exists {
return nil, nil return nil, nil
@@ -20,7 +20,7 @@ func (db *DB)getAsSortedSet(key string)(*SortedSet.SortedSet, reply.ErrorReply)
return sortedSet, nil return sortedSet, nil
} }
func (db *DB) getOrInitSortedSet(key string)(sortedSet *SortedSet.SortedSet, inited bool, errReply reply.ErrorReply) { func (db *DB) getOrInitSortedSet(key string) (sortedSet *SortedSet.SortedSet, inited bool, errReply reply.ErrorReply) {
sortedSet, errReply = db.getAsSortedSet(key) sortedSet, errReply = db.getAsSortedSet(key)
if errReply != nil { if errReply != nil {
return nil, false, errReply return nil, false, errReply
@@ -37,22 +37,22 @@ func (db *DB) getOrInitSortedSet(key string)(sortedSet *SortedSet.SortedSet, ini
} }
func ZAdd(db *DB, args [][]byte) redis.Reply { func ZAdd(db *DB, args [][]byte) redis.Reply {
if len(args) < 3 || len(args) % 2 != 1 { if len(args) < 3 || len(args)%2 != 1 {
return reply.MakeErrReply("ERR wrong number of arguments for 'zadd' command") return reply.MakeErrReply("ERR wrong number of arguments for 'zadd' command")
} }
key := string(args[0]) key := string(args[0])
size := (len(args) - 1) / 2 size := (len(args) - 1) / 2
elements := make([]*SortedSet.Element, size) elements := make([]*SortedSet.Element, size)
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
scoreValue := args[2 * i + 1] scoreValue := args[2*i+1]
member := string(args[2 * i + 2]) member := string(args[2*i+2])
score, err := strconv.ParseFloat(string(scoreValue), 64) score, err := strconv.ParseFloat(string(scoreValue), 64)
if err != nil { if err != nil {
return reply.MakeErrReply("ERR value is not a valid float") return reply.MakeErrReply("ERR value is not a valid float")
} }
elements[i] = &SortedSet.Element{ elements[i] = &SortedSet.Element{
Member:member, Member: member,
Score:score, Score: score,
} }
} }
@@ -219,7 +219,7 @@ func ZRevRange(db *DB, args [][]byte) redis.Reply {
return range0(db, key, start, stop, withScores, true) return range0(db, key, start, stop, withScores, true)
} }
func range0(db *DB, key string, start int64, stop int64, withScores bool, desc bool)redis.Reply { func range0(db *DB, key string, start int64, stop int64, withScores bool, desc bool) redis.Reply {
// lock key // lock key
db.Locker.RLock(key) db.Locker.RLock(key)
defer db.Locker.RUnLock(key) defer db.Locker.RUnLock(key)
@@ -235,14 +235,14 @@ func range0(db *DB, key string, start int64, stop int64, withScores bool, desc b
// compute index // compute index
size := sortedSet.Len() // assert: size > 0 size := sortedSet.Len() // assert: size > 0
if start < -1 * size { if start < -1*size {
start = 0 start = 0
} else if start < 0 { } else if start < 0 {
start = size + start start = size + start
} else if start >= size { } else if start >= size {
return &reply.EmptyMultiBulkReply{} return &reply.EmptyMultiBulkReply{}
} }
if stop < -1 * size { if stop < -1*size {
stop = 0 stop = 0
} else if stop < 0 { } else if stop < 0 {
stop = size + stop + 1 stop = size + stop + 1
@@ -258,7 +258,7 @@ func range0(db *DB, key string, start int64, stop int64, withScores bool, desc b
// assert: start in [0, size - 1], stop in [start, size] // assert: start in [0, size - 1], stop in [start, size]
slice := sortedSet.Range(start, stop, desc) slice := sortedSet.Range(start, stop, desc)
if withScores { if withScores {
result := make([][]byte, len(slice) * 2) result := make([][]byte, len(slice)*2)
i := 0 i := 0
for _, element := range slice { for _, element := range slice {
result[i] = []byte(element.Member) result[i] = []byte(element.Member)
@@ -313,7 +313,7 @@ func ZCount(db *DB, args [][]byte) redis.Reply {
/* /*
* param limit: limit < 0 means no limit * param limit: limit < 0 means no limit
*/ */
func rangeByScore0(db *DB, key string, min *SortedSet.ScoreBorder, max *SortedSet.ScoreBorder, offset int64, limit int64, withScores bool, desc bool)redis.Reply { func rangeByScore0(db *DB, key string, min *SortedSet.ScoreBorder, max *SortedSet.ScoreBorder, offset int64, limit int64, withScores bool, desc bool) redis.Reply {
// lock key // lock key
db.Locker.RLock(key) db.Locker.RLock(key)
defer db.Locker.RUnLock(key) defer db.Locker.RUnLock(key)
@@ -329,7 +329,7 @@ func rangeByScore0(db *DB, key string, min *SortedSet.ScoreBorder, max *SortedSe
slice := sortedSet.RangeByScore(min, max, offset, limit, desc) slice := sortedSet.RangeByScore(min, max, offset, limit, desc)
if withScores { if withScores {
result := make([][]byte, len(slice) * 2) result := make([][]byte, len(slice)*2)
i := 0 i := 0
for _, element := range slice { for _, element := range slice {
result[i] = []byte(element.Member) result[i] = []byte(element.Member)
@@ -505,14 +505,14 @@ func ZRemRangeByRank(db *DB, args [][]byte) redis.Reply {
// compute index // compute index
size := sortedSet.Len() // assert: size > 0 size := sortedSet.Len() // assert: size > 0
if start < -1 * size { if start < -1*size {
start = 0 start = 0
} else if start < 0 { } else if start < 0 {
start = size + start start = size + start
} else if start >= size { } else if start >= size {
return reply.MakeIntReply(0) return reply.MakeIntReply(0)
} }
if stop < -1 * size { if stop < -1*size {
stop = 0 stop = 0
} else if stop < 0 { } else if stop < 0 {
stop = size + stop + 1 stop = size + stop + 1

View File

@@ -6,6 +6,6 @@ type Connection interface {
// client should keep its subscribing channels // client should keep its subscribing channels
SubsChannel(channel string) SubsChannel(channel string)
UnSubsChannel(channel string) UnSubsChannel(channel string)
SubsCount()int SubsCount() int
GetChannels()[]string GetChannels() []string
} }

View File

@@ -1,5 +1,5 @@
package redis package redis
type Reply interface { type Reply interface {
ToBytes()[]byte ToBytes() []byte
} }

View File

@@ -1,13 +1,13 @@
package tcp package tcp
import ( import (
"net"
"context" "context"
"net"
) )
type HandleFunc func(ctx context.Context, conn net.Conn) type HandleFunc func(ctx context.Context, conn net.Conn)
type Handler interface { type Handler interface {
Handle(ctx context.Context, conn net.Conn) Handle(ctx context.Context, conn net.Conn)
Close()error Close() error
} }

View File

@@ -1,11 +1,11 @@
package files package files
import ( import (
"mime/multipart"
"io/ioutil"
"path"
"os"
"fmt" "fmt"
"io/ioutil"
"mime/multipart"
"os"
"path"
) )
func GetSize(f multipart.File) (int, error) { func GetSize(f multipart.File) (int, error) {
@@ -69,7 +69,7 @@ func MustOpen(fileName, dir string) (*os.File, error) {
return nil, fmt.Errorf("error during make dir %s, err: %s", dir, err) return nil, fmt.Errorf("error during make dir %s, err: %s", dir, err)
} }
f, err := Open(dir + string(os.PathSeparator) + fileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0644) f, err := Open(dir+string(os.PathSeparator)+fileName, os.O_APPEND|os.O_CREATE|os.O_RDWR, 0644)
if err != nil { if err != nil {
return nil, fmt.Errorf("fail to open file, err: %s", err) return nil, fmt.Errorf("fail to open file, err: %s", err)
} }

View File

@@ -69,7 +69,7 @@ func Distance(latitude1, longitude1, latitude2, longitude2 float64) float64 {
radLat2 := degRad(latitude2) radLat2 := degRad(latitude2)
a := radLat1 - radLat2 a := radLat1 - radLat2
b := degRad(longitude1) - degRad(longitude2) b := degRad(longitude1) - degRad(longitude2)
return 2 * EarthRadius * math.Asin(math.Sqrt(math.Pow(math.Sin(a/2), 2) + return 2 * EarthRadius * math.Asin(math.Sqrt(math.Pow(math.Sin(a/2), 2)+
math.Cos(radLat1)*math.Cos(radLat2)*math.Pow(math.Sin(b/2), 2))) math.Cos(radLat1)*math.Cos(radLat2)*math.Pow(math.Sin(b/2), 2)))
} }

View File

@@ -4,16 +4,14 @@ import "sync/atomic"
type AtomicBool uint32 type AtomicBool uint32
func (b *AtomicBool)Get()bool { func (b *AtomicBool) Get() bool {
return atomic.LoadUint32((*uint32)(b)) != 0 return atomic.LoadUint32((*uint32)(b)) != 0
} }
func (b *AtomicBool)Set(v bool) { func (b *AtomicBool) Set(v bool) {
if v { if v {
atomic.StoreUint32((*uint32)(b), 1) atomic.StoreUint32((*uint32)(b), 1)
} else { } else {
atomic.StoreUint32((*uint32)(b), 0) atomic.StoreUint32((*uint32)(b), 0)
} }
} }

View File

@@ -9,20 +9,20 @@ type Wait struct {
wg sync.WaitGroup wg sync.WaitGroup
} }
func (w *Wait)Add(delta int) { func (w *Wait) Add(delta int) {
w.wg.Add(delta) w.wg.Add(delta)
} }
func (w *Wait)Done() { func (w *Wait) Done() {
w.wg.Done() w.wg.Done()
} }
func (w *Wait)Wait() { func (w *Wait) Wait() {
w.wg.Wait() w.wg.Wait()
} }
// return isTimeout // return isTimeout
func (w *Wait)WaitWithTimeout(timeout time.Duration)bool { func (w *Wait) WaitWithTimeout(timeout time.Duration) bool {
c := make(chan bool) c := make(chan bool)
go func() { go func() {
defer close(c) defer close(c)

View File

@@ -310,7 +310,6 @@ func (client *Client) handleRead() error {
client.finishRequest(reply) client.finishRequest(reply)
} }
// finish reply // finish reply
expectedArgsCount = 0 expectedArgsCount = 0
receivedCount = 0 receivedCount = 0

View File

@@ -1,42 +1,42 @@
package reply package reply
type PongReply struct {} type PongReply struct{}
var PongBytes = []byte("+PONG\r\n") var PongBytes = []byte("+PONG\r\n")
func (r *PongReply)ToBytes()[]byte { func (r *PongReply) ToBytes() []byte {
return PongBytes return PongBytes
} }
type OkReply struct {} type OkReply struct{}
var okBytes = []byte("+OK\r\n") var okBytes = []byte("+OK\r\n")
func (r *OkReply)ToBytes()[]byte { func (r *OkReply) ToBytes() []byte {
return okBytes return okBytes
} }
var nullBulkBytes = []byte("$-1\r\n") var nullBulkBytes = []byte("$-1\r\n")
type NullBulkReply struct {} type NullBulkReply struct{}
func (r *NullBulkReply)ToBytes()[]byte { func (r *NullBulkReply) ToBytes() []byte {
return nullBulkBytes return nullBulkBytes
} }
var emptyMultiBulkBytes = []byte("*0\r\n") var emptyMultiBulkBytes = []byte("*0\r\n")
type EmptyMultiBulkReply struct {} type EmptyMultiBulkReply struct{}
func (r *EmptyMultiBulkReply)ToBytes()[]byte { func (r *EmptyMultiBulkReply) ToBytes() []byte {
return emptyMultiBulkBytes return emptyMultiBulkBytes
} }
// reply nothing, for commands like subscribe // reply nothing, for commands like subscribe
type NoReply struct {} type NoReply struct{}
var NoBytes = []byte("") var NoBytes = []byte("")
func (r *NoReply)ToBytes()[]byte { func (r *NoReply) ToBytes() []byte {
return NoBytes return NoBytes
} }

View File

@@ -1,15 +1,15 @@
package reply package reply
// UnknownErr // UnknownErr
type UnknownErrReply struct {} type UnknownErrReply struct{}
var unknownErrBytes = []byte("-Err unknown\r\n") var unknownErrBytes = []byte("-Err unknown\r\n")
func (r *UnknownErrReply)ToBytes()[]byte { func (r *UnknownErrReply) ToBytes() []byte {
return unknownErrBytes return unknownErrBytes
} }
func (r *UnknownErrReply) Error()string { func (r *UnknownErrReply) Error() string {
return "Err unknown" return "Err unknown"
} }
@@ -18,37 +18,37 @@ type ArgNumErrReply struct {
Cmd string Cmd string
} }
func (r *ArgNumErrReply)ToBytes()[]byte { func (r *ArgNumErrReply) ToBytes() []byte {
return []byte("-ERR wrong number of arguments for '" + r.Cmd + "' command\r\n") return []byte("-ERR wrong number of arguments for '" + r.Cmd + "' command\r\n")
} }
func (r *ArgNumErrReply) Error()string { func (r *ArgNumErrReply) Error() string {
return "ERR wrong number of arguments for '" + r.Cmd + "' command" return "ERR wrong number of arguments for '" + r.Cmd + "' command"
} }
// SyntaxErr // SyntaxErr
type SyntaxErrReply struct {} type SyntaxErrReply struct{}
var syntaxErrBytes = []byte("-Err syntax error\r\n") var syntaxErrBytes = []byte("-Err syntax error\r\n")
func (r *SyntaxErrReply)ToBytes()[]byte { func (r *SyntaxErrReply) ToBytes() []byte {
return syntaxErrBytes return syntaxErrBytes
} }
func (r *SyntaxErrReply)Error()string { func (r *SyntaxErrReply) Error() string {
return "Err syntax error" return "Err syntax error"
} }
// WrongTypeErr // WrongTypeErr
type WrongTypeErrReply struct {} type WrongTypeErrReply struct{}
var wrongTypeErrBytes = []byte("-WRONGTYPE Operation against a key holding the wrong kind of value\r\n") var wrongTypeErrBytes = []byte("-WRONGTYPE Operation against a key holding the wrong kind of value\r\n")
func (r *WrongTypeErrReply)ToBytes()[]byte { func (r *WrongTypeErrReply) ToBytes() []byte {
return wrongTypeErrBytes return wrongTypeErrBytes
} }
func (r *WrongTypeErrReply)Error()string { func (r *WrongTypeErrReply) Error() string {
return "WRONGTYPE Operation against a key holding the wrong kind of value" return "WRONGTYPE Operation against a key holding the wrong kind of value"
} }
@@ -58,10 +58,10 @@ type ProtocolErrReply struct {
Msg string Msg string
} }
func (r *ProtocolErrReply)ToBytes()[]byte { func (r *ProtocolErrReply) ToBytes() []byte {
return []byte("-ERR Protocol error: '" + r.Msg + "'\r\n") return []byte("-ERR Protocol error: '" + r.Msg + "'\r\n")
} }
func (r *ProtocolErrReply) Error()string { func (r *ProtocolErrReply) Error() string {
return "ERR Protocol error: '" + r.Msg return "ERR Protocol error: '" + r.Msg
} }

View File

@@ -110,7 +110,6 @@ func (r *IntReply) ToBytes() []byte {
return []byte(":" + strconv.FormatInt(r.Code, 10) + CRLF) return []byte(":" + strconv.FormatInt(r.Code, 10) + CRLF)
} }
/* ---- Error Reply ---- */ /* ---- Error Reply ---- */
type ErrorReply interface { type ErrorReply interface {

View File

@@ -31,7 +31,7 @@ type Client struct {
subs map[string]bool subs map[string]bool
} }
func (c *Client)Close()error { func (c *Client) Close() error {
c.waitingReply.WaitWithTimeout(10 * time.Second) c.waitingReply.WaitWithTimeout(10 * time.Second)
_ = c.conn.Close() _ = c.conn.Close()
return nil return nil
@@ -43,7 +43,7 @@ func MakeClient(conn net.Conn) *Client {
} }
} }
func (c *Client)Write(b []byte)error { func (c *Client) Write(b []byte) error {
if b == nil || len(b) == 0 { if b == nil || len(b) == 0 {
return nil return nil
} }
@@ -54,7 +54,7 @@ func (c *Client)Write(b []byte)error {
return err return err
} }
func (c *Client)SubsChannel(channel string) { func (c *Client) SubsChannel(channel string) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@@ -64,7 +64,7 @@ func (c *Client)SubsChannel(channel string) {
c.subs[channel] = true c.subs[channel] = true
} }
func (c *Client)UnSubsChannel(channel string) { func (c *Client) UnSubsChannel(channel string) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@@ -74,14 +74,14 @@ func (c *Client)UnSubsChannel(channel string) {
delete(c.subs, channel) delete(c.subs, channel)
} }
func (c *Client)SubsCount()int { func (c *Client) SubsCount() int {
if c.subs == nil { if c.subs == nil {
return 0 return 0
} }
return len(c.subs) return len(c.subs)
} }
func (c *Client)GetChannels()[]string { func (c *Client) GetChannels() []string {
if c.subs == nil { if c.subs == nil {
return make([]string, 0) return make([]string, 0)
} }

View File

@@ -5,15 +5,15 @@ package tcp
*/ */
import ( import (
"net"
"context"
"bufio" "bufio"
"context"
"github.com/HDT3213/godis/src/lib/logger" "github.com/HDT3213/godis/src/lib/logger"
"sync"
"io"
"github.com/HDT3213/godis/src/lib/sync/atomic" "github.com/HDT3213/godis/src/lib/sync/atomic"
"time"
"github.com/HDT3213/godis/src/lib/sync/wait" "github.com/HDT3213/godis/src/lib/sync/wait"
"io"
"net"
"sync"
"time"
) )
type EchoHandler struct { type EchoHandler struct {
@@ -21,9 +21,8 @@ type EchoHandler struct {
closing atomic.AtomicBool closing atomic.AtomicBool
} }
func MakeEchoHandler()(*EchoHandler) { func MakeEchoHandler() *EchoHandler {
return &EchoHandler{ return &EchoHandler{}
}
} }
type Client struct { type Client struct {
@@ -31,19 +30,19 @@ type Client struct {
Waiting wait.Wait Waiting wait.Wait
} }
func (c *Client)Close()error { func (c *Client) Close() error {
c.Waiting.WaitWithTimeout(10 * time.Second) c.Waiting.WaitWithTimeout(10 * time.Second)
c.Conn.Close() c.Conn.Close()
return nil return nil
} }
func (h *EchoHandler)Handle(ctx context.Context, conn net.Conn) { func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
if h.closing.Get() { if h.closing.Get() {
// closing handler refuse new connection // closing handler refuse new connection
conn.Close() conn.Close()
} }
client := &Client { client := &Client{
Conn: conn, Conn: conn,
} }
h.activeConn.Store(client, 1) h.activeConn.Store(client, 1)
@@ -70,11 +69,11 @@ func (h *EchoHandler)Handle(ctx context.Context, conn net.Conn) {
} }
} }
func (h *EchoHandler)Close()error { func (h *EchoHandler) Close() error {
logger.Info("handler shuting down...") logger.Info("handler shuting down...")
h.closing.Set(true) h.closing.Set(true)
// TODO: concurrent wait // TODO: concurrent wait
h.activeConn.Range(func(key interface{}, val interface{})bool { h.activeConn.Range(func(key interface{}, val interface{}) bool {
client := key.(*Client) client := key.(*Client)
client.Close() client.Close()
return true return true

View File

@@ -45,7 +45,6 @@ func ListenAndServe(cfg *Config, handler tcp.Handler) {
} }
}() }()
// listen port // listen port
logger.Info(fmt.Sprintf("bind: %s, start listening...", cfg.Address)) logger.Info(fmt.Sprintf("bind: %s, start listening...", cfg.Address))
defer func() { defer func() {