mirror of
https://github.com/HDT3213/godis.git
synced 2025-10-30 03:31:54 +08:00
add unittests and bug fix
This commit is contained in:
@@ -1,92 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/hdt3213/godis/lib/sync/wait"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Connection represents a connection with a redis-cli
|
||||
type Connection struct {
|
||||
conn net.Conn
|
||||
|
||||
// waiting util reply finished
|
||||
waitingReply wait.Wait
|
||||
|
||||
// lock while server sending response
|
||||
mu sync.Mutex
|
||||
|
||||
// subscribing channels
|
||||
subs map[string]bool
|
||||
}
|
||||
|
||||
// Close disconnect with the client
|
||||
func (c *Connection) Close() error {
|
||||
c.waitingReply.WaitWithTimeout(10 * time.Second)
|
||||
_ = c.conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewConn creates Connection instance
|
||||
func NewConn(conn net.Conn) *Connection {
|
||||
return &Connection{
|
||||
conn: conn,
|
||||
}
|
||||
}
|
||||
|
||||
// Write sends response to client over tcp connection
|
||||
func (c *Connection) Write(b []byte) error {
|
||||
if len(b) == 0 {
|
||||
return nil
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
_, err := c.conn.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
// Subscribe add current connection into subscribers of the given channel
|
||||
func (c *Connection) Subscribe(channel string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.subs == nil {
|
||||
c.subs = make(map[string]bool)
|
||||
}
|
||||
c.subs[channel] = true
|
||||
}
|
||||
|
||||
// UnSubscribe removes current connection into subscribers of the given channel
|
||||
func (c *Connection) UnSubscribe(channel string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.subs == nil {
|
||||
return
|
||||
}
|
||||
delete(c.subs, channel)
|
||||
}
|
||||
|
||||
// SubsCount returns the number of subscribing channels
|
||||
func (c *Connection) SubsCount() int {
|
||||
if c.subs == nil {
|
||||
return 0
|
||||
}
|
||||
return len(c.subs)
|
||||
}
|
||||
|
||||
// GetChannels returns all subscribing channels
|
||||
func (c *Connection) GetChannels() []string {
|
||||
if c.subs == nil {
|
||||
return make([]string, 0)
|
||||
}
|
||||
channels := make([]string, len(c.subs))
|
||||
i := 0
|
||||
for channel := range c.subs {
|
||||
channels[i] = channel
|
||||
i++
|
||||
}
|
||||
return channels
|
||||
}
|
||||
50
redis/server/pubsub_test.go
Normal file
50
redis/server/pubsub_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/hdt3213/godis/lib/utils"
|
||||
"github.com/hdt3213/godis/pubsub"
|
||||
"github.com/hdt3213/godis/redis/connection"
|
||||
"github.com/hdt3213/godis/redis/parser"
|
||||
"github.com/hdt3213/godis/redis/reply/asserts"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPublish(t *testing.T) {
|
||||
hub := pubsub.MakeHub()
|
||||
channel := utils.RandString(5)
|
||||
msg := utils.RandString(5)
|
||||
conn := &connection.FakeConn{}
|
||||
pubsub.Subscribe(hub, conn, utils.ToBytesList(channel))
|
||||
conn.Clean() // clean subscribe success
|
||||
pubsub.Publish(hub, utils.ToBytesList(channel, msg))
|
||||
data := conn.Bytes()
|
||||
ret, err := parser.ParseOne(data)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
asserts.AssertMultiBulkReply(t, ret, []string{
|
||||
"message",
|
||||
channel,
|
||||
msg,
|
||||
})
|
||||
|
||||
// unsubscribe
|
||||
pubsub.UnSubscribe(hub, conn, utils.ToBytesList(channel))
|
||||
conn.Clean()
|
||||
pubsub.Publish(hub, utils.ToBytesList(channel, msg))
|
||||
data = conn.Bytes()
|
||||
if len(data) > 0 {
|
||||
t.Error("expect no msg")
|
||||
}
|
||||
|
||||
// unsubscribe all
|
||||
pubsub.Subscribe(hub, conn, utils.ToBytesList(channel))
|
||||
pubsub.UnSubscribe(hub, conn, utils.ToBytesList())
|
||||
conn.Clean()
|
||||
pubsub.Publish(hub, utils.ToBytesList(channel, msg))
|
||||
data = conn.Bytes()
|
||||
if len(data) > 0 {
|
||||
t.Error("expect no msg")
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/hdt3213/godis/interface/db"
|
||||
"github.com/hdt3213/godis/lib/logger"
|
||||
"github.com/hdt3213/godis/lib/sync/atomic"
|
||||
"github.com/hdt3213/godis/redis/connection"
|
||||
"github.com/hdt3213/godis/redis/parser"
|
||||
"github.com/hdt3213/godis/redis/reply"
|
||||
"io"
|
||||
@@ -45,7 +46,7 @@ func MakeHandler() *Handler {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) closeClient(client *Connection) {
|
||||
func (h *Handler) closeClient(client *connection.Connection) {
|
||||
_ = client.Close()
|
||||
h.db.AfterClientClose(client)
|
||||
h.activeConn.Delete(client)
|
||||
@@ -58,10 +59,10 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
client := NewConn(conn)
|
||||
client := connection.NewConn(conn)
|
||||
h.activeConn.Store(client, 1)
|
||||
|
||||
ch := parser.Parse(conn)
|
||||
ch := parser.ParseStream(conn)
|
||||
for payload := range ch {
|
||||
if payload.Err != nil {
|
||||
if payload.Err == io.EOF ||
|
||||
@@ -69,7 +70,7 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
strings.Contains(payload.Err.Error(), "use of closed network connection") {
|
||||
// connection closed
|
||||
h.closeClient(client)
|
||||
logger.Info("connection closed: " + client.conn.RemoteAddr().String())
|
||||
logger.Info("connection closed: " + client.RemoteAddr().String())
|
||||
return
|
||||
} else {
|
||||
// protocol err
|
||||
@@ -77,7 +78,7 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
err := client.Write(errReply.ToBytes())
|
||||
if err != nil {
|
||||
h.closeClient(client)
|
||||
logger.Info("connection closed: " + client.conn.RemoteAddr().String())
|
||||
logger.Info("connection closed: " + client.RemoteAddr().String())
|
||||
return
|
||||
}
|
||||
continue
|
||||
@@ -107,7 +108,7 @@ func (h *Handler) Close() error {
|
||||
h.closing.Set(true)
|
||||
// TODO: concurrent wait
|
||||
h.activeConn.Range(func(key interface{}, val interface{}) bool {
|
||||
client := key.(*Connection)
|
||||
client := key.(*connection.Connection)
|
||||
_ = client.Close()
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"github.com/hdt3213/godis/tcp"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestListenAndServe(t *testing.T) {
|
||||
@@ -39,4 +40,5 @@ func TestListenAndServe(t *testing.T) {
|
||||
return
|
||||
}
|
||||
closeChan <- struct{}{}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user