add some unittests and bug fix

This commit is contained in:
hdt3213
2021-05-02 18:48:44 +08:00
parent 37779717e4
commit 9f3ac88b36
19 changed files with 256 additions and 146 deletions

View File

@@ -3,12 +3,12 @@ package cluster
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/hdt3213/godis/cluster/idgenerator"
"github.com/hdt3213/godis/config" "github.com/hdt3213/godis/config"
"github.com/hdt3213/godis/datastruct/dict" "github.com/hdt3213/godis/datastruct/dict"
"github.com/hdt3213/godis/db" "github.com/hdt3213/godis/db"
"github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/interface/redis"
"github.com/hdt3213/godis/lib/consistenthash" "github.com/hdt3213/godis/lib/consistenthash"
"github.com/hdt3213/godis/lib/idgenerator"
"github.com/hdt3213/godis/lib/logger" "github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/redis/reply" "github.com/hdt3213/godis/redis/reply"
"github.com/jolestar/go-commons-pool/v2" "github.com/jolestar/go-commons-pool/v2"
@@ -44,7 +44,7 @@ func MakeCluster() *Cluster {
peerPicker: consistenthash.New(replicas, nil), peerPicker: consistenthash.New(replicas, nil),
peerConnection: make(map[string]*pool.ObjectPool), peerConnection: make(map[string]*pool.ObjectPool),
idGenerator: idgenerator.MakeGenerator("godis", config.Properties.Self), idGenerator: idgenerator.MakeGenerator(config.Properties.Self),
} }
contains := make(map[string]struct{}) contains := make(map[string]struct{})
nodes := make([]string, 0, len(config.Properties.Peers)+1) nodes := make([]string, 0, len(config.Properties.Peers)+1)
@@ -56,7 +56,7 @@ func MakeCluster() *Cluster {
nodes = append(nodes, peer) nodes = append(nodes, peer)
} }
nodes = append(nodes, config.Properties.Self) nodes = append(nodes, config.Properties.Self)
cluster.peerPicker.Add(nodes...) cluster.peerPicker.AddNode(nodes...)
ctx := context.Background() ctx := context.Background()
for _, peer := range nodes { for _, peer := range nodes {
cluster.peerConnection[peer] = pool.NewObjectPoolWithDefaultConfig(ctx, &ConnectionFactory{ cluster.peerConnection[peer] = pool.NewObjectPoolWithDefaultConfig(ctx, &ConnectionFactory{
@@ -122,7 +122,7 @@ func makeArgs(cmd string, args ...string) [][]byte {
func (cluster *Cluster) groupBy(keys []string) map[string][]string { func (cluster *Cluster) groupBy(keys []string) map[string][]string {
result := make(map[string][]string) result := make(map[string][]string)
for _, key := range keys { for _, key := range keys {
peer := cluster.peerPicker.Get(key) peer := cluster.peerPicker.PickNode(key)
group, ok := result[peer] group, ok := result[peer]
if !ok { if !ok {
group = make([]string, 0) group = make([]string, 0)

View File

@@ -1,85 +0,0 @@
package idgenerator
import (
"hash/fnv"
"log"
"sync"
"time"
)
const (
workerIdBits int64 = 5
datacenterIdBits int64 = 5
sequenceBits int64 = 12
maxWorkerId int64 = -1 ^ (-1 << uint64(workerIdBits))
maxDatacenterId int64 = -1 ^ (-1 << uint64(datacenterIdBits))
maxSequence int64 = -1 ^ (-1 << uint64(sequenceBits))
timeLeft uint8 = 22
dataLeft uint8 = 17
workLeft uint8 = 12
twepoch int64 = 1525705533000
)
type IdGenerator struct {
mu *sync.Mutex
lastStamp int64
workerId int64
dataCenterId int64
sequence int64
}
func MakeGenerator(cluster string, node string) *IdGenerator {
fnv64 := fnv.New64()
_, _ = fnv64.Write([]byte(cluster))
dataCenterId := int64(fnv64.Sum64())
fnv64.Reset()
_, _ = fnv64.Write([]byte(node))
workerId := int64(fnv64.Sum64())
return &IdGenerator{
mu: &sync.Mutex{},
lastStamp: -1,
dataCenterId: dataCenterId,
workerId: workerId,
sequence: 1,
}
}
func (w *IdGenerator) getCurrentTime() int64 {
return time.Now().UnixNano() / 1e6
}
func (w *IdGenerator) NextId() int64 {
w.mu.Lock()
defer w.mu.Unlock()
timestamp := w.getCurrentTime()
if timestamp < w.lastStamp {
log.Fatal("can not generate id")
}
if w.lastStamp == timestamp {
w.sequence = (w.sequence + 1) & maxSequence
if w.sequence == 0 {
for timestamp <= w.lastStamp {
timestamp = w.getCurrentTime()
}
}
} else {
w.sequence = 0
}
w.lastStamp = timestamp
return ((timestamp - twepoch) << timeLeft) | (w.dataCenterId << dataLeft) | (w.workerId << workLeft) | w.sequence
}
func (w *IdGenerator) tilNextMillis() int64 {
timestamp := w.getCurrentTime()
if timestamp <= w.lastStamp {
timestamp = w.getCurrentTime()
}
return timestamp
}

View File

@@ -146,7 +146,7 @@ func MSetNX(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
size := argCount / 2 size := argCount / 2
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
key := string(args[2*i]) key := string(args[2*i])
currentPeer := cluster.peerPicker.Get(key) currentPeer := cluster.peerPicker.PickNode(key)
if peer == "" { if peer == "" {
peer = currentPeer peer = currentPeer
} else { } else {

View File

@@ -13,8 +13,8 @@ func Rename(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
src := string(args[1]) src := string(args[1])
dest := string(args[2]) dest := string(args[2])
srcPeer := cluster.peerPicker.Get(src) srcPeer := cluster.peerPicker.PickNode(src)
destPeer := cluster.peerPicker.Get(dest) destPeer := cluster.peerPicker.PickNode(dest)
if srcPeer != destPeer { if srcPeer != destPeer {
return reply.MakeErrReply("ERR rename must within one slot in cluster mode") return reply.MakeErrReply("ERR rename must within one slot in cluster mode")
@@ -29,8 +29,8 @@ func RenameNx(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
src := string(args[1]) src := string(args[1])
dest := string(args[2]) dest := string(args[2])
srcPeer := cluster.peerPicker.Get(src) srcPeer := cluster.peerPicker.PickNode(src)
destPeer := cluster.peerPicker.Get(dest) destPeer := cluster.peerPicker.PickNode(dest)
if srcPeer != destPeer { if srcPeer != destPeer {
return reply.MakeErrReply("ERR rename must within one slot in cluster mode") return reply.MakeErrReply("ERR rename must within one slot in cluster mode")

View File

@@ -116,6 +116,6 @@ func MakeRouter() map[string]CmdFunc {
// relay command to responsible peer, and return its reply to client // relay command to responsible peer, and return its reply to client
func defaultFunc(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply { func defaultFunc(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
key := string(args[1]) key := string(args[1])
peer := cluster.peerPicker.Get(key) peer := cluster.peerPicker.PickNode(key)
return cluster.Relay(peer, c, args) return cluster.Relay(peer, c, args)
} }

View File

@@ -3,6 +3,7 @@ package config
import ( import (
"bufio" "bufio"
"github.com/hdt3213/godis/lib/logger" "github.com/hdt3213/godis/lib/logger"
"io"
"log" "log"
"os" "os"
"reflect" "reflect"
@@ -31,18 +32,12 @@ func init() {
} }
} }
func LoadConfig(configFilename string) *PropertyHolder { func parse(src io.Reader) *PropertyHolder {
config := Properties config := &PropertyHolder{}
file, err := os.Open(configFilename)
if err != nil {
log.Print(err)
return config
}
defer file.Close()
// read config file // read config file
rawMap := make(map[string]string) rawMap := make(map[string]string)
scanner := bufio.NewScanner(file) scanner := bufio.NewScanner(src)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if len(line) > 0 && line[0] == '#' { if len(line) > 0 && line[0] == '#' {
@@ -96,5 +91,10 @@ func LoadConfig(configFilename string) *PropertyHolder {
} }
func SetupConfig(configFilename string) { func SetupConfig(configFilename string) {
Properties = LoadConfig(configFilename) file, err := os.Open(configFilename)
if err != nil {
log.Print(err)
}
defer file.Close()
Properties = parse(file)
} }

30
config/config_test.go Normal file
View File

@@ -0,0 +1,30 @@
package config
import (
"strings"
"testing"
)
func TestParse(t *testing.T) {
src := "bind 0.0.0.0\n" +
"port 6399\n" +
"appendonly yes\n" +
"peers a,b"
p := parse(strings.NewReader(src))
if p == nil {
t.Error("cannot get result")
return
}
if p.Bind != "0.0.0.0" {
t.Error("string parse failed")
}
if p.Port != 6399 {
t.Error("int parse failed")
}
if !p.AppendOnly {
t.Error("bool parse failed")
}
if len(p.Peers) != 2 || p.Peers[0] != "a" || p.Peers[1] != "b" {
t.Error("list parse failed")
}
}

View File

@@ -6,7 +6,7 @@ import (
"testing" "testing"
) )
func TestPut(t *testing.T) { func TestConcurrentPut(t *testing.T) {
d := MakeConcurrent(0) d := MakeConcurrent(0)
count := 100 count := 100
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -35,7 +35,7 @@ func TestPut(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestPutIfAbsent(t *testing.T) { func TestConcurrentPutIfAbsent(t *testing.T) {
d := MakeConcurrent(0) d := MakeConcurrent(0)
count := 100 count := 100
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -80,7 +80,7 @@ func TestPutIfAbsent(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestPutIfExists(t *testing.T) { func TestConcurrentPutIfExists(t *testing.T) {
d := MakeConcurrent(0) d := MakeConcurrent(0)
count := 100 count := 100
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -113,7 +113,7 @@ func TestPutIfExists(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestRemove(t *testing.T) { func TestConcurrentRemove(t *testing.T) {
d := MakeConcurrent(0) d := MakeConcurrent(0)
// remove head node // remove head node
@@ -220,7 +220,7 @@ func TestRemove(t *testing.T) {
} }
} }
func TestForEach(t *testing.T) { func TestConcurrentForEach(t *testing.T) {
d := MakeConcurrent(0) d := MakeConcurrent(0)
size := 100 size := 100
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
@@ -242,3 +242,28 @@ func TestForEach(t *testing.T) {
t.Error("remove test failed: expected " + strconv.Itoa(size) + ", actual: " + strconv.Itoa(i)) t.Error("remove test failed: expected " + strconv.Itoa(size) + ", actual: " + strconv.Itoa(i))
} }
} }
func TestConcurrentRandomKey(t *testing.T) {
d := MakeConcurrent(0)
count := 100
for i := 0; i < count; i++ {
key := "k" + strconv.Itoa(i)
d.Put(key, i)
}
fetchSize := 10
result := d.RandomKeys(fetchSize)
if len(result) != fetchSize {
t.Errorf("expect %d random keys acturally %d", fetchSize, len(result))
}
result = d.RandomDistinctKeys(fetchSize)
distinct := make(map[string]struct{})
for _, key := range result {
distinct[key] = struct{}{}
}
if len(result) != fetchSize {
t.Errorf("expect %d random keys acturally %d", fetchSize, len(result))
}
if len(result) > len(distinct) {
t.Errorf("get duplicated keys in result")
}
}

View File

@@ -4,8 +4,8 @@ type Connection interface {
Write([]byte) error Write([]byte) error
// client should keep its subscribing channels // client should keep its subscribing channels
SubsChannel(channel string) Subscribe(channel string)
UnSubsChannel(channel string) UnSubscribe(channel string)
SubsCount() int SubsCount() int
GetChannels() []string GetChannels() []string
} }

View File

@@ -7,6 +7,7 @@ import (
type HandleFunc func(ctx context.Context, conn net.Conn) type HandleFunc func(ctx context.Context, conn net.Conn)
// Handler represents application server over tcp
type Handler interface { type Handler interface {
Handle(ctx context.Context, conn net.Conn) Handle(ctx context.Context, conn net.Conn)
Close() error Close() error

View File

@@ -32,7 +32,7 @@ func (m *Map) IsEmpty() bool {
return len(m.keys) == 0 return len(m.keys) == 0
} }
func (m *Map) Add(keys ...string) { func (m *Map) AddNode(keys ...string) {
for _, key := range keys { for _, key := range keys {
if key == "" { if key == "" {
continue continue
@@ -59,8 +59,8 @@ func getPartitionKey(key string) string {
return key[beg+1 : end] return key[beg+1 : end]
} }
// Get gets the closest item in the hash to the provided key. // PickNode gets the closest item in the hash to the provided key.
func (m *Map) Get(key string) string { func (m *Map) PickNode(key string) string {
if m.IsEmpty() { if m.IsEmpty() {
return "" return ""
} }

View File

@@ -0,0 +1,17 @@
package consistenthash
import "testing"
func TestHash(t *testing.T) {
m := New(3, nil)
m.AddNode("a", "b", "c", "d")
if m.PickNode("zxc") != "a" {
t.Error("wrong answer")
}
if m.PickNode("123{abc}") != "b" {
t.Error("wrong answer")
}
if m.PickNode("abc") != "b" {
t.Error("wrong answer")
}
}

View File

@@ -2,22 +2,9 @@ package files
import ( import (
"fmt" "fmt"
"io/ioutil"
"mime/multipart"
"os" "os"
"path"
) )
func GetSize(f multipart.File) (int, error) {
content, err := ioutil.ReadAll(f)
return len(content), err
}
func GetExt(fileName string) string {
return path.Ext(fileName)
}
func CheckNotExist(src string) bool { func CheckNotExist(src string) bool {
_, err := os.Stat(src) _, err := os.Stat(src)
@@ -26,7 +13,6 @@ func CheckNotExist(src string) bool {
func CheckPermission(src string) bool { func CheckPermission(src string) bool {
_, err := os.Stat(src) _, err := os.Stat(src)
return os.IsPermission(err) return os.IsPermission(err)
} }
@@ -36,7 +22,6 @@ func IsNotExistMkDir(src string) error {
return err return err
} }
} }
return nil return nil
} }

View File

@@ -0,0 +1,67 @@
package idgenerator
import (
"hash/fnv"
"log"
"sync"
"time"
)
const (
// Epoch is set to the twitter snowflake epoch of Nov 04 2010 01:42:54 UTC in milliseconds
// You may customize this to set a different epoch for your application.
Epoch int64 = 1288834974657
maxSequence int64 = -1 ^ (-1 << uint64(nodeLeft))
timeLeft uint8 = 22
nodeLeft uint8 = 10
nodeMask int64 = -1 ^ (-1 << uint64(timeLeft-nodeLeft))
)
type IdGenerator struct {
mu *sync.Mutex
lastStamp int64
workerId int64
sequence int64
epoch time.Time
}
func MakeGenerator(node string) *IdGenerator {
fnv64 := fnv.New64()
_, _ = fnv64.Write([]byte(node))
nodeId := int64(fnv64.Sum64()) & nodeMask
var curTime = time.Now()
epoch := curTime.Add(time.Unix(Epoch/1000, (Epoch%1000)*1000000).Sub(curTime))
return &IdGenerator{
mu: &sync.Mutex{},
lastStamp: -1,
workerId: nodeId,
sequence: 1,
epoch: epoch,
}
}
func (w *IdGenerator) NextId() int64 {
w.mu.Lock()
defer w.mu.Unlock()
timestamp := time.Since(w.epoch).Nanoseconds() / 1000000
if timestamp < w.lastStamp {
log.Fatal("can not generate id")
}
if w.lastStamp == timestamp {
w.sequence = (w.sequence + 1) & maxSequence
if w.sequence == 0 {
for timestamp <= w.lastStamp {
timestamp = time.Since(w.epoch).Nanoseconds() / 1000000
}
}
} else {
w.sequence = 0
}
w.lastStamp = timestamp
id := (timestamp << timeLeft) | (w.workerId << nodeLeft) | w.sequence
//fmt.Printf("%d %d %d\n", timestamp, w.sequence, id)
return id
}

View File

@@ -0,0 +1,17 @@
package idgenerator
import "testing"
func TestMGenerator(t *testing.T) {
gen := MakeGenerator("a")
ids := make(map[int64]struct{})
size := int(maxSequence) - 1
for i := 0; i < size; i++ {
id := gen.NextId()
_, ok := ids[id]
if ok {
t.Errorf("duplicated id: %d, time: %d, seq: %d", id, gen.lastStamp, gen.sequence)
}
ids[id] = struct{}{}
}
}

View File

@@ -25,7 +25,7 @@ func makeMsg(t string, channel string, code int64) []byte {
* return: is new subscribed * return: is new subscribed
*/ */
func subscribe0(hub *Hub, channel string, client redis.Connection) bool { func subscribe0(hub *Hub, channel string, client redis.Connection) bool {
client.SubsChannel(channel) client.Subscribe(channel)
// add into hub.subs // add into hub.subs
raw, ok := hub.subs.Get(channel) raw, ok := hub.subs.Get(channel)
@@ -48,7 +48,7 @@ func subscribe0(hub *Hub, channel string, client redis.Connection) bool {
* return: is actually un-subscribe * return: is actually un-subscribe
*/ */
func unsubscribe0(hub *Hub, channel string, client redis.Connection) bool { func unsubscribe0(hub *Hub, channel string, client redis.Connection) bool {
client.UnSubsChannel(channel) client.UnSubscribe(channel)
// remove from hub.subs // remove from hub.subs
raw, ok := hub.subs.Get(channel) raw, ok := hub.subs.Get(channel)

View File

@@ -7,8 +7,8 @@ import (
"time" "time"
) )
// abstract of active client // Connection represents a connection with a redis-cli
type Client struct { type Connection struct {
conn net.Conn conn net.Conn
// waiting util reply finished // waiting util reply finished
@@ -21,20 +21,23 @@ type Client struct {
subs map[string]bool subs map[string]bool
} }
func (c *Client) Close() error { // Close disconnect with the client
func (c *Connection) Close() error {
c.waitingReply.WaitWithTimeout(10 * time.Second) c.waitingReply.WaitWithTimeout(10 * time.Second)
_ = c.conn.Close() _ = c.conn.Close()
return nil return nil
} }
func MakeClient(conn net.Conn) *Client { // NewConn creates Connection instance
return &Client{ func NewConn(conn net.Conn) *Connection {
return &Connection{
conn: conn, conn: conn,
} }
} }
func (c *Client) Write(b []byte) error { // Write sends response to client over tcp connection
if b == nil || len(b) == 0 { func (c *Connection) Write(b []byte) error {
if len(b) == 0 {
return nil return nil
} }
c.mu.Lock() c.mu.Lock()
@@ -44,7 +47,8 @@ func (c *Client) Write(b []byte) error {
return err return err
} }
func (c *Client) SubsChannel(channel string) { // Subscribe add current connection into subscribers of the given channel
func (c *Connection) Subscribe(channel string) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@@ -54,7 +58,8 @@ func (c *Client) SubsChannel(channel string) {
c.subs[channel] = true c.subs[channel] = true
} }
func (c *Client) UnSubsChannel(channel string) { // UnSubscribe removes current connection into subscribers of the given channel
func (c *Connection) UnSubscribe(channel string) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@@ -64,14 +69,16 @@ func (c *Client) UnSubsChannel(channel string) {
delete(c.subs, channel) delete(c.subs, channel)
} }
func (c *Client) SubsCount() int { // SubsCount returns the number of subscribing channels
func (c *Connection) SubsCount() int {
if c.subs == nil { if c.subs == nil {
return 0 return 0
} }
return len(c.subs) return len(c.subs)
} }
func (c *Client) GetChannels() []string { // GetChannels returns all subscribing channels
func (c *Connection) GetChannels() []string {
if c.subs == nil { if c.subs == nil {
return make([]string, 0) return make([]string, 0)
} }

View File

@@ -24,12 +24,14 @@ var (
UnknownErrReplyBytes = []byte("-ERR unknown\r\n") UnknownErrReplyBytes = []byte("-ERR unknown\r\n")
) )
// Handler implements tcp.Handler and serves as a redis server
type Handler struct { type Handler struct {
activeConn sync.Map // *client -> placeholder activeConn sync.Map // *client -> placeholder
db db.DB db db.DB
closing atomic.AtomicBool // refusing new client and new request closing atomic.AtomicBool // refusing new client and new request
} }
// MakeHandler creates a Handler instance
func MakeHandler() *Handler { func MakeHandler() *Handler {
var db db.DB var db db.DB
if config.Properties.Self != "" && if config.Properties.Self != "" &&
@@ -43,19 +45,20 @@ func MakeHandler() *Handler {
} }
} }
func (h *Handler) closeClient(client *Client) { func (h *Handler) closeClient(client *Connection) {
_ = client.Close() _ = client.Close()
h.db.AfterClientClose(client) h.db.AfterClientClose(client)
h.activeConn.Delete(client) h.activeConn.Delete(client)
} }
// Handle receives and executes redis commands
func (h *Handler) Handle(ctx context.Context, conn net.Conn) { func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
if h.closing.Get() { if h.closing.Get() {
// closing handler refuse new connection // closing handler refuse new connection
_ = conn.Close() _ = conn.Close()
} }
client := MakeClient(conn) client := NewConn(conn)
h.activeConn.Store(client, 1) h.activeConn.Store(client, 1)
ch := parser.Parse(conn) ch := parser.Parse(conn)
@@ -98,12 +101,13 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
} }
} }
// Close stops handler
func (h *Handler) Close() error { func (h *Handler) Close() error {
logger.Info("handler shuting down...") logger.Info("handler shuting down...")
h.closing.Set(true) h.closing.Set(true)
// TODO: concurrent wait // TODO: concurrent wait
h.activeConn.Range(func(key interface{}, val interface{}) bool { h.activeConn.Range(func(key interface{}, val interface{}) bool {
client := key.(*Client) client := key.(*Connection)
_ = client.Close() _ = client.Close()
return true return true
}) })

View File

@@ -0,0 +1,42 @@
package server
import (
"bufio"
"github.com/hdt3213/godis/tcp"
"net"
"testing"
)
func TestListenAndServe(t *testing.T) {
var err error
closeChan := make(chan struct{})
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Error(err)
return
}
addr := listener.Addr().String()
go tcp.ListenAndServe(listener, MakeHandler(), closeChan)
conn, err := net.Dial("tcp", addr)
if err != nil {
t.Error(err)
return
}
_, err = conn.Write([]byte("PING\r\n"))
if err != nil {
t.Error(err)
return
}
bufReader := bufio.NewReader(conn)
line, _, err := bufReader.ReadLine()
if err != nil {
t.Error(err)
return
}
if string(line) != "+PONG" {
t.Error("get wrong response")
return
}
closeChan <- struct{}{}
}