mirror of
https://github.com/HDT3213/godis.git
synced 2025-10-05 08:46:56 +08:00
Add some unit tests and bug fixes
This commit is contained in:
@@ -3,12 +3,12 @@ package cluster
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/hdt3213/godis/cluster/idgenerator"
|
||||
"github.com/hdt3213/godis/config"
|
||||
"github.com/hdt3213/godis/datastruct/dict"
|
||||
"github.com/hdt3213/godis/db"
|
||||
"github.com/hdt3213/godis/interface/redis"
|
||||
"github.com/hdt3213/godis/lib/consistenthash"
|
||||
"github.com/hdt3213/godis/lib/idgenerator"
|
||||
"github.com/hdt3213/godis/lib/logger"
|
||||
"github.com/hdt3213/godis/redis/reply"
|
||||
"github.com/jolestar/go-commons-pool/v2"
|
||||
@@ -44,7 +44,7 @@ func MakeCluster() *Cluster {
|
||||
peerPicker: consistenthash.New(replicas, nil),
|
||||
peerConnection: make(map[string]*pool.ObjectPool),
|
||||
|
||||
idGenerator: idgenerator.MakeGenerator("godis", config.Properties.Self),
|
||||
idGenerator: idgenerator.MakeGenerator(config.Properties.Self),
|
||||
}
|
||||
contains := make(map[string]struct{})
|
||||
nodes := make([]string, 0, len(config.Properties.Peers)+1)
|
||||
@@ -56,7 +56,7 @@ func MakeCluster() *Cluster {
|
||||
nodes = append(nodes, peer)
|
||||
}
|
||||
nodes = append(nodes, config.Properties.Self)
|
||||
cluster.peerPicker.Add(nodes...)
|
||||
cluster.peerPicker.AddNode(nodes...)
|
||||
ctx := context.Background()
|
||||
for _, peer := range nodes {
|
||||
cluster.peerConnection[peer] = pool.NewObjectPoolWithDefaultConfig(ctx, &ConnectionFactory{
|
||||
@@ -122,7 +122,7 @@ func makeArgs(cmd string, args ...string) [][]byte {
|
||||
func (cluster *Cluster) groupBy(keys []string) map[string][]string {
|
||||
result := make(map[string][]string)
|
||||
for _, key := range keys {
|
||||
peer := cluster.peerPicker.Get(key)
|
||||
peer := cluster.peerPicker.PickNode(key)
|
||||
group, ok := result[peer]
|
||||
if !ok {
|
||||
group = make([]string, 0)
|
||||
|
@@ -1,85 +0,0 @@
|
||||
package idgenerator
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Bit layout constants for generated ids, from low to high:
// 12 bits sequence | 5 bits worker id | 5 bits datacenter id | timestamp.
const (
	workerIdBits     int64 = 5
	datacenterIdBits int64 = 5
	sequenceBits     int64 = 12

	maxWorkerId     int64 = -1 ^ (-1 << uint64(workerIdBits))
	maxDatacenterId int64 = -1 ^ (-1 << uint64(datacenterIdBits))
	maxSequence     int64 = -1 ^ (-1 << uint64(sequenceBits))

	timeLeft uint8 = 22 // timestamp occupies the bits above sequence+worker+datacenter
	dataLeft uint8 = 17 // datacenter id starts above worker id and sequence
	workLeft uint8 = 12 // worker id starts above the sequence

	// custom epoch (2018-05-07 UTC, in milliseconds) subtracted from the
	// wall clock so the timestamp field stays small
	twepoch int64 = 1525705533000
)

// IdGenerator generates unique, roughly time-ordered int64 ids
// (snowflake style). It is safe for concurrent use.
type IdGenerator struct {
	mu           *sync.Mutex
	lastStamp    int64 // last millisecond an id was generated in
	workerId     int64
	dataCenterId int64
	sequence     int64 // per-millisecond counter
}

// MakeGenerator creates an IdGenerator whose datacenter and worker ids are
// derived by hashing the cluster and node names.
func MakeGenerator(cluster string, node string) *IdGenerator {
	fnv64 := fnv.New64()
	_, _ = fnv64.Write([]byte(cluster))
	// fix: mask the 64-bit hash down to its 5-bit field. An unmasked hash
	// overflows into the timestamp bits when OR-ed together in NextId,
	// which can make ids from different milliseconds collide.
	dataCenterId := int64(fnv64.Sum64()) & maxDatacenterId

	fnv64.Reset()
	_, _ = fnv64.Write([]byte(node))
	workerId := int64(fnv64.Sum64()) & maxWorkerId

	return &IdGenerator{
		mu:           &sync.Mutex{},
		lastStamp:    -1,
		dataCenterId: dataCenterId,
		workerId:     workerId,
		sequence:     1,
	}
}

// getCurrentTime returns the current wall-clock time in milliseconds.
func (w *IdGenerator) getCurrentTime() int64 {
	return time.Now().UnixNano() / 1e6
}

// NextId returns the next unique id. When the per-millisecond sequence
// overflows it busy-waits for the next millisecond.
func (w *IdGenerator) NextId() int64 {
	w.mu.Lock()
	defer w.mu.Unlock()

	timestamp := w.getCurrentTime()
	if timestamp < w.lastStamp {
		// clock moved backwards; refusing to generate avoids duplicate ids
		log.Fatal("can not generate id")
	}
	if w.lastStamp == timestamp {
		w.sequence = (w.sequence + 1) & maxSequence
		if w.sequence == 0 {
			// sequence exhausted within this millisecond: spin until the next one
			for timestamp <= w.lastStamp {
				timestamp = w.getCurrentTime()
			}
		}
	} else {
		w.sequence = 0
	}
	w.lastStamp = timestamp

	return ((timestamp - twepoch) << timeLeft) | (w.dataCenterId << dataLeft) | (w.workerId << workLeft) | w.sequence
}

// tilNextMillis spins until the clock has advanced past lastStamp and
// returns the new timestamp.
func (w *IdGenerator) tilNextMillis() int64 {
	timestamp := w.getCurrentTime()
	// fix: was a single `if`, which could still return a timestamp <= lastStamp;
	// loop until the clock actually advances
	for timestamp <= w.lastStamp {
		timestamp = w.getCurrentTime()
	}
	return timestamp
}
|
@@ -146,7 +146,7 @@ func MSetNX(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||
size := argCount / 2
|
||||
for i := 0; i < size; i++ {
|
||||
key := string(args[2*i])
|
||||
currentPeer := cluster.peerPicker.Get(key)
|
||||
currentPeer := cluster.peerPicker.PickNode(key)
|
||||
if peer == "" {
|
||||
peer = currentPeer
|
||||
} else {
|
||||
|
@@ -13,8 +13,8 @@ func Rename(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||
src := string(args[1])
|
||||
dest := string(args[2])
|
||||
|
||||
srcPeer := cluster.peerPicker.Get(src)
|
||||
destPeer := cluster.peerPicker.Get(dest)
|
||||
srcPeer := cluster.peerPicker.PickNode(src)
|
||||
destPeer := cluster.peerPicker.PickNode(dest)
|
||||
|
||||
if srcPeer != destPeer {
|
||||
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
|
||||
@@ -29,8 +29,8 @@ func RenameNx(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||
src := string(args[1])
|
||||
dest := string(args[2])
|
||||
|
||||
srcPeer := cluster.peerPicker.Get(src)
|
||||
destPeer := cluster.peerPicker.Get(dest)
|
||||
srcPeer := cluster.peerPicker.PickNode(src)
|
||||
destPeer := cluster.peerPicker.PickNode(dest)
|
||||
|
||||
if srcPeer != destPeer {
|
||||
return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
|
||||
|
@@ -116,6 +116,6 @@ func MakeRouter() map[string]CmdFunc {
|
||||
// relay command to responsible peer, and return its reply to client
|
||||
func defaultFunc(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
|
||||
key := string(args[1])
|
||||
peer := cluster.peerPicker.Get(key)
|
||||
peer := cluster.peerPicker.PickNode(key)
|
||||
return cluster.Relay(peer, c, args)
|
||||
}
|
||||
|
@@ -3,6 +3,7 @@ package config
|
||||
import (
|
||||
"bufio"
|
||||
"github.com/hdt3213/godis/lib/logger"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
@@ -31,18 +32,12 @@ func init() {
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(configFilename string) *PropertyHolder {
|
||||
config := Properties
|
||||
file, err := os.Open(configFilename)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
return config
|
||||
}
|
||||
defer file.Close()
|
||||
func parse(src io.Reader) *PropertyHolder {
|
||||
config := &PropertyHolder{}
|
||||
|
||||
// read config file
|
||||
rawMap := make(map[string]string)
|
||||
scanner := bufio.NewScanner(file)
|
||||
scanner := bufio.NewScanner(src)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if len(line) > 0 && line[0] == '#' {
|
||||
@@ -96,5 +91,10 @@ func LoadConfig(configFilename string) *PropertyHolder {
|
||||
}
|
||||
|
||||
func SetupConfig(configFilename string) {
|
||||
Properties = LoadConfig(configFilename)
|
||||
file, err := os.Open(configFilename)
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
}
|
||||
defer file.Close()
|
||||
Properties = parse(file)
|
||||
}
|
||||
|
30
config/config_test.go
Normal file
30
config/config_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
src := "bind 0.0.0.0\n" +
|
||||
"port 6399\n" +
|
||||
"appendonly yes\n" +
|
||||
"peers a,b"
|
||||
p := parse(strings.NewReader(src))
|
||||
if p == nil {
|
||||
t.Error("cannot get result")
|
||||
return
|
||||
}
|
||||
if p.Bind != "0.0.0.0" {
|
||||
t.Error("string parse failed")
|
||||
}
|
||||
if p.Port != 6399 {
|
||||
t.Error("int parse failed")
|
||||
}
|
||||
if !p.AppendOnly {
|
||||
t.Error("bool parse failed")
|
||||
}
|
||||
if len(p.Peers) != 2 || p.Peers[0] != "a" || p.Peers[1] != "b" {
|
||||
t.Error("list parse failed")
|
||||
}
|
||||
}
|
@@ -6,7 +6,7 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPut(t *testing.T) {
|
||||
func TestConcurrentPut(t *testing.T) {
|
||||
d := MakeConcurrent(0)
|
||||
count := 100
|
||||
var wg sync.WaitGroup
|
||||
@@ -35,7 +35,7 @@ func TestPut(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestPutIfAbsent(t *testing.T) {
|
||||
func TestConcurrentPutIfAbsent(t *testing.T) {
|
||||
d := MakeConcurrent(0)
|
||||
count := 100
|
||||
var wg sync.WaitGroup
|
||||
@@ -80,7 +80,7 @@ func TestPutIfAbsent(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestPutIfExists(t *testing.T) {
|
||||
func TestConcurrentPutIfExists(t *testing.T) {
|
||||
d := MakeConcurrent(0)
|
||||
count := 100
|
||||
var wg sync.WaitGroup
|
||||
@@ -113,7 +113,7 @@ func TestPutIfExists(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestRemove(t *testing.T) {
|
||||
func TestConcurrentRemove(t *testing.T) {
|
||||
d := MakeConcurrent(0)
|
||||
|
||||
// remove head node
|
||||
@@ -220,7 +220,7 @@ func TestRemove(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestForEach(t *testing.T) {
|
||||
func TestConcurrentForEach(t *testing.T) {
|
||||
d := MakeConcurrent(0)
|
||||
size := 100
|
||||
for i := 0; i < size; i++ {
|
||||
@@ -242,3 +242,28 @@ func TestForEach(t *testing.T) {
|
||||
t.Error("remove test failed: expected " + strconv.Itoa(size) + ", actual: " + strconv.Itoa(i))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentRandomKey(t *testing.T) {
|
||||
d := MakeConcurrent(0)
|
||||
count := 100
|
||||
for i := 0; i < count; i++ {
|
||||
key := "k" + strconv.Itoa(i)
|
||||
d.Put(key, i)
|
||||
}
|
||||
fetchSize := 10
|
||||
result := d.RandomKeys(fetchSize)
|
||||
if len(result) != fetchSize {
|
||||
t.Errorf("expect %d random keys acturally %d", fetchSize, len(result))
|
||||
}
|
||||
result = d.RandomDistinctKeys(fetchSize)
|
||||
distinct := make(map[string]struct{})
|
||||
for _, key := range result {
|
||||
distinct[key] = struct{}{}
|
||||
}
|
||||
if len(result) != fetchSize {
|
||||
t.Errorf("expect %d random keys acturally %d", fetchSize, len(result))
|
||||
}
|
||||
if len(result) > len(distinct) {
|
||||
t.Errorf("get duplicated keys in result")
|
||||
}
|
||||
}
|
@@ -4,8 +4,8 @@ type Connection interface {
|
||||
Write([]byte) error
|
||||
|
||||
// client should keep its subscribing channels
|
||||
SubsChannel(channel string)
|
||||
UnSubsChannel(channel string)
|
||||
Subscribe(channel string)
|
||||
UnSubscribe(channel string)
|
||||
SubsCount() int
|
||||
GetChannels() []string
|
||||
}
|
||||
|
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
type HandleFunc func(ctx context.Context, conn net.Conn)
|
||||
|
||||
// Handler represents application server over tcp
|
||||
type Handler interface {
|
||||
Handle(ctx context.Context, conn net.Conn)
|
||||
Close() error
|
||||
|
@@ -32,7 +32,7 @@ func (m *Map) IsEmpty() bool {
|
||||
return len(m.keys) == 0
|
||||
}
|
||||
|
||||
func (m *Map) Add(keys ...string) {
|
||||
func (m *Map) AddNode(keys ...string) {
|
||||
for _, key := range keys {
|
||||
if key == "" {
|
||||
continue
|
||||
@@ -59,8 +59,8 @@ func getPartitionKey(key string) string {
|
||||
return key[beg+1 : end]
|
||||
}
|
||||
|
||||
// Get gets the closest item in the hash to the provided key.
|
||||
func (m *Map) Get(key string) string {
|
||||
// PickNode gets the closest item in the hash to the provided key.
|
||||
func (m *Map) PickNode(key string) string {
|
||||
if m.IsEmpty() {
|
||||
return ""
|
||||
}
|
||||
|
17
lib/consistenthash/consistenthash_test.go
Normal file
17
lib/consistenthash/consistenthash_test.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package consistenthash
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestHash(t *testing.T) {
|
||||
m := New(3, nil)
|
||||
m.AddNode("a", "b", "c", "d")
|
||||
if m.PickNode("zxc") != "a" {
|
||||
t.Error("wrong answer")
|
||||
}
|
||||
if m.PickNode("123{abc}") != "b" {
|
||||
t.Error("wrong answer")
|
||||
}
|
||||
if m.PickNode("abc") != "b" {
|
||||
t.Error("wrong answer")
|
||||
}
|
||||
}
|
@@ -2,22 +2,9 @@ package files
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path"
|
||||
)
|
||||
|
||||
func GetSize(f multipart.File) (int, error) {
|
||||
content, err := ioutil.ReadAll(f)
|
||||
|
||||
return len(content), err
|
||||
}
|
||||
|
||||
// GetExt returns the file name extension including the leading dot
// (e.g. ".txt" for "a/b.txt"), or the empty string when the final
// path element has no dot.
func GetExt(fileName string) string {
	return path.Ext(fileName)
}
|
||||
|
||||
func CheckNotExist(src string) bool {
|
||||
_, err := os.Stat(src)
|
||||
|
||||
@@ -26,7 +13,6 @@ func CheckNotExist(src string) bool {
|
||||
|
||||
// CheckPermission reports whether stat-ing src fails specifically with a
// permission error. It returns false both for accessible paths and for
// other stat failures (e.g. the path does not exist).
func CheckPermission(src string) bool {
	if _, err := os.Stat(src); err != nil {
		return os.IsPermission(err)
	}
	return false
}
|
||||
|
||||
@@ -36,7 +22,6 @@ func IsNotExistMkDir(src string) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
67
lib/idgenerator/snowflake.go
Normal file
67
lib/idgenerator/snowflake.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package idgenerator
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
	// Epoch is set to the twitter snowflake epoch of Nov 04 2010 01:42:54 UTC in milliseconds
	// You may customize this to set a different epoch for your application.
	Epoch int64 = 1288834974657

	// NOTE(review): maxSequence is derived from nodeLeft (10 bits) while
	// nodeMask is derived from timeLeft-nodeLeft (12 bits), i.e. a 10-bit
	// sequence and a 12-bit node id — confirm this asymmetric layout is intended.
	maxSequence int64 = -1 ^ (-1 << uint64(nodeLeft))
	timeLeft    uint8 = 22 // number of low bits below the timestamp field
	nodeLeft    uint8 = 10 // number of low bits below the node id field
	nodeMask    int64 = -1 ^ (-1 << uint64(timeLeft-nodeLeft))
)

// IdGenerator generates unique ids using a snowflake-style bit layout:
// timestamp | node id | sequence.
type IdGenerator struct {
	mu        *sync.Mutex
	lastStamp int64 // last millisecond (relative to epoch) an id was generated in
	workerId  int64
	sequence  int64 // per-millisecond counter
	epoch     time.Time
}

// MakeGenerator creates an IdGenerator whose node id is derived by hashing
// the given node name and masking it down to the node id field.
func MakeGenerator(node string) *IdGenerator {
	fnv64 := fnv.New64()
	_, _ = fnv64.Write([]byte(node))
	nodeId := int64(fnv64.Sum64()) & nodeMask

	var curTime = time.Now()
	// The Epoch instant is expressed as an offset from curTime so the result
	// carries curTime's monotonic clock reading; time.Since(epoch) in NextId
	// then tracks the monotonic clock rather than the adjustable wall clock.
	epoch := curTime.Add(time.Unix(Epoch/1000, (Epoch%1000)*1000000).Sub(curTime))

	return &IdGenerator{
		mu:        &sync.Mutex{},
		lastStamp: -1,
		workerId:  nodeId,
		sequence:  1,
		epoch:     epoch,
	}
}

// NextId returns the next unique id. It is safe for concurrent use; when the
// per-millisecond sequence overflows it busy-waits for the next millisecond.
func (w *IdGenerator) NextId() int64 {
	w.mu.Lock()
	defer w.mu.Unlock()

	// milliseconds since the custom epoch (monotonic, see MakeGenerator)
	timestamp := time.Since(w.epoch).Nanoseconds() / 1000000
	if timestamp < w.lastStamp {
		log.Fatal("can not generate id")
	}
	if w.lastStamp == timestamp {
		w.sequence = (w.sequence + 1) & maxSequence
		if w.sequence == 0 {
			// sequence exhausted within this millisecond: spin until the next one
			for timestamp <= w.lastStamp {
				timestamp = time.Since(w.epoch).Nanoseconds() / 1000000
			}
		}
	} else {
		w.sequence = 0
	}
	w.lastStamp = timestamp
	// assemble the id: timestamp | node id | sequence
	id := (timestamp << timeLeft) | (w.workerId << nodeLeft) | w.sequence
	return id
}
|
17
lib/idgenerator/snowflake_test.go
Normal file
17
lib/idgenerator/snowflake_test.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package idgenerator
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestMGenerator(t *testing.T) {
|
||||
gen := MakeGenerator("a")
|
||||
ids := make(map[int64]struct{})
|
||||
size := int(maxSequence) - 1
|
||||
for i := 0; i < size; i++ {
|
||||
id := gen.NextId()
|
||||
_, ok := ids[id]
|
||||
if ok {
|
||||
t.Errorf("duplicated id: %d, time: %d, seq: %d", id, gen.lastStamp, gen.sequence)
|
||||
}
|
||||
ids[id] = struct{}{}
|
||||
}
|
||||
}
|
@@ -25,7 +25,7 @@ func makeMsg(t string, channel string, code int64) []byte {
|
||||
* return: is new subscribed
|
||||
*/
|
||||
func subscribe0(hub *Hub, channel string, client redis.Connection) bool {
|
||||
client.SubsChannel(channel)
|
||||
client.Subscribe(channel)
|
||||
|
||||
// add into hub.subs
|
||||
raw, ok := hub.subs.Get(channel)
|
||||
@@ -48,7 +48,7 @@ func subscribe0(hub *Hub, channel string, client redis.Connection) bool {
|
||||
* return: is actually un-subscribe
|
||||
*/
|
||||
func unsubscribe0(hub *Hub, channel string, client redis.Connection) bool {
|
||||
client.UnSubsChannel(channel)
|
||||
client.UnSubscribe(channel)
|
||||
|
||||
// remove from hub.subs
|
||||
raw, ok := hub.subs.Get(channel)
|
||||
|
@@ -7,8 +7,8 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// abstract of active client
|
||||
type Client struct {
|
||||
// Connection represents a connection with a redis-cli
|
||||
type Connection struct {
|
||||
conn net.Conn
|
||||
|
||||
// waiting util reply finished
|
||||
@@ -21,20 +21,23 @@ type Client struct {
|
||||
subs map[string]bool
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
// Close disconnect with the client
|
||||
func (c *Connection) Close() error {
|
||||
c.waitingReply.WaitWithTimeout(10 * time.Second)
|
||||
_ = c.conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func MakeClient(conn net.Conn) *Client {
|
||||
return &Client{
|
||||
// NewConn creates Connection instance
|
||||
func NewConn(conn net.Conn) *Connection {
|
||||
return &Connection{
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) Write(b []byte) error {
|
||||
if b == nil || len(b) == 0 {
|
||||
// Write sends response to client over tcp connection
|
||||
func (c *Connection) Write(b []byte) error {
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
c.mu.Lock()
|
||||
@@ -44,7 +47,8 @@ func (c *Client) Write(b []byte) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) SubsChannel(channel string) {
|
||||
// Subscribe add current connection into subscribers of the given channel
|
||||
func (c *Connection) Subscribe(channel string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
@@ -54,7 +58,8 @@ func (c *Client) SubsChannel(channel string) {
|
||||
c.subs[channel] = true
|
||||
}
|
||||
|
||||
func (c *Client) UnSubsChannel(channel string) {
|
||||
// UnSubscribe removes current connection into subscribers of the given channel
|
||||
func (c *Connection) UnSubscribe(channel string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
@@ -64,14 +69,16 @@ func (c *Client) UnSubsChannel(channel string) {
|
||||
delete(c.subs, channel)
|
||||
}
|
||||
|
||||
func (c *Client) SubsCount() int {
|
||||
// SubsCount returns the number of subscribing channels
|
||||
func (c *Connection) SubsCount() int {
|
||||
if c.subs == nil {
|
||||
return 0
|
||||
}
|
||||
return len(c.subs)
|
||||
}
|
||||
|
||||
func (c *Client) GetChannels() []string {
|
||||
// GetChannels returns all subscribing channels
|
||||
func (c *Connection) GetChannels() []string {
|
||||
if c.subs == nil {
|
||||
return make([]string, 0)
|
||||
}
|
@@ -24,12 +24,14 @@ var (
|
||||
UnknownErrReplyBytes = []byte("-ERR unknown\r\n")
|
||||
)
|
||||
|
||||
// Handler implements tcp.Handler and serves as a redis server
|
||||
type Handler struct {
|
||||
activeConn sync.Map // *client -> placeholder
|
||||
db db.DB
|
||||
closing atomic.AtomicBool // refusing new client and new request
|
||||
}
|
||||
|
||||
// MakeHandler creates a Handler instance
|
||||
func MakeHandler() *Handler {
|
||||
var db db.DB
|
||||
if config.Properties.Self != "" &&
|
||||
@@ -43,19 +45,20 @@ func MakeHandler() *Handler {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) closeClient(client *Client) {
|
||||
func (h *Handler) closeClient(client *Connection) {
|
||||
_ = client.Close()
|
||||
h.db.AfterClientClose(client)
|
||||
h.activeConn.Delete(client)
|
||||
}
|
||||
|
||||
// Handle receives and executes redis commands
|
||||
func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
if h.closing.Get() {
|
||||
// closing handler refuse new connection
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
client := MakeClient(conn)
|
||||
client := NewConn(conn)
|
||||
h.activeConn.Store(client, 1)
|
||||
|
||||
ch := parser.Parse(conn)
|
||||
@@ -98,12 +101,13 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops handler
|
||||
func (h *Handler) Close() error {
|
||||
logger.Info("handler shuting down...")
|
||||
h.closing.Set(true)
|
||||
// TODO: concurrent wait
|
||||
h.activeConn.Range(func(key interface{}, val interface{}) bool {
|
||||
client := key.(*Client)
|
||||
client := key.(*Connection)
|
||||
_ = client.Close()
|
||||
return true
|
||||
})
|
42
redis/server/server_test.go
Normal file
42
redis/server/server_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"github.com/hdt3213/godis/tcp"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestListenAndServe(t *testing.T) {
|
||||
var err error
|
||||
closeChan := make(chan struct{})
|
||||
listener, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
addr := listener.Addr().String()
|
||||
go tcp.ListenAndServe(listener, MakeHandler(), closeChan)
|
||||
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
_, err = conn.Write([]byte("PING\r\n"))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
bufReader := bufio.NewReader(conn)
|
||||
line, _, err := bufReader.ReadLine()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if string(line) != "+PONG" {
|
||||
t.Error("get wrong response")
|
||||
return
|
||||
}
|
||||
closeChan <- struct{}{}
|
||||
}
|
Reference in New Issue
Block a user