This commit is contained in:
hdt3213
2021-06-05 16:43:06 +08:00
parent 083d5b325e
commit 9d03314359
4 changed files with 34 additions and 44 deletions

View File

@@ -12,7 +12,7 @@ import (
type Connection struct { type Connection struct {
conn net.Conn conn net.Conn
// waiting util reply finished // waiting until reply finished
waitingReply wait.Wait waitingReply wait.Wait
// lock while server sending response // lock while server sending response
@@ -50,7 +50,11 @@ func (c *Connection) Write(b []byte) error {
return nil return nil
} }
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() c.waitingReply.Add(1)
defer func() {
c.waitingReply.Done()
c.mu.Unlock()
}()
_, err := c.conn.Write(b) _, err := c.conn.Write(b)
return err return err

View File

@@ -39,16 +39,15 @@ func ParseOne(data []byte) (redis.Reply, error) {
} }
type readState struct { type readState struct {
downloading bool readingMultiLine bool
expectedArgsCount int expectedArgsCount int
receivedCount int
msgType byte msgType byte
args [][]byte args [][]byte
fixedLen int64 bulkLen int64
} }
func (s *readState) finished() bool { func (s *readState) finished() bool {
return s.expectedArgsCount > 0 && s.receivedCount == s.expectedArgsCount return s.expectedArgsCount > 0 && len(s.args) == s.expectedArgsCount
} }
func parse0(reader io.Reader, ch chan<- *Payload) { func parse0(reader io.Reader, ch chan<- *Payload) {
@@ -82,7 +81,7 @@ func parse0(reader io.Reader, ch chan<- *Payload) {
} }
// parse line // parse line
if !state.downloading { if !state.readingMultiLine {
// receive new response // receive new response
if msg[0] == '*' { if msg[0] == '*' {
// multi bulk reply // multi bulk reply
@@ -110,7 +109,7 @@ func parse0(reader io.Reader, ch chan<- *Payload) {
state = readState{} // reset state state = readState{} // reset state
continue continue
} }
if state.fixedLen == -1 { // null bulk reply if state.bulkLen == -1 { // null bulk reply
ch <- &Payload{ ch <- &Payload{
Data: &reply.NullBulkReply{}, Data: &reply.NullBulkReply{},
} }
@@ -129,7 +128,7 @@ func parse0(reader io.Reader, ch chan<- *Payload) {
} }
} else { } else {
// receive following bulk reply // receive following bulk reply
err = readBulkBody(msg, &state) err = readBody(msg, &state)
if err != nil { if err != nil {
ch <- &Payload{ ch <- &Payload{
Err: errors.New("protocol error: " + string(msg)), Err: errors.New("protocol error: " + string(msg)),
@@ -158,7 +157,7 @@ func parse0(reader io.Reader, ch chan<- *Payload) {
func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) { func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) {
var msg []byte var msg []byte
var err error var err error
if state.fixedLen == 0 { // read normal line if state.bulkLen == 0 { // read normal line
msg, err = bufReader.ReadBytes('\n') msg, err = bufReader.ReadBytes('\n')
if err != nil { if err != nil {
return nil, true, err return nil, true, err
@@ -167,7 +166,7 @@ func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) {
return nil, false, errors.New("protocol error: " + string(msg)) return nil, false, errors.New("protocol error: " + string(msg))
} }
} else { // read bulk line (binary safe) } else { // read bulk line (binary safe)
msg = make([]byte, state.fixedLen+2) msg = make([]byte, state.bulkLen+2)
_, err = io.ReadFull(bufReader, msg) _, err = io.ReadFull(bufReader, msg)
if err != nil { if err != nil {
return nil, true, err return nil, true, err
@@ -177,7 +176,7 @@ func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) {
msg[len(msg)-1] != '\n' { msg[len(msg)-1] != '\n' {
return nil, false, errors.New("protocol error: " + string(msg)) return nil, false, errors.New("protocol error: " + string(msg))
} }
state.fixedLen = 0 state.bulkLen = 0
} }
return msg, false, nil return msg, false, nil
} }
@@ -195,10 +194,9 @@ func parseMultiBulkHeader(msg []byte, state *readState) error {
} else if expectedLine > 0 { } else if expectedLine > 0 {
// first line of multi bulk reply // first line of multi bulk reply
state.msgType = msg[0] state.msgType = msg[0]
state.downloading = true state.readingMultiLine = true
state.expectedArgsCount = int(expectedLine) state.expectedArgsCount = int(expectedLine)
state.receivedCount = 0 state.args = make([][]byte, 0, expectedLine)
state.args = make([][]byte, expectedLine)
return nil return nil
} else { } else {
return errors.New("protocol error: " + string(msg)) return errors.New("protocol error: " + string(msg))
@@ -207,18 +205,17 @@ func parseMultiBulkHeader(msg []byte, state *readState) error {
func parseBulkHeader(msg []byte, state *readState) error { func parseBulkHeader(msg []byte, state *readState) error {
var err error var err error
state.fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64) state.bulkLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
if err != nil { if err != nil {
return errors.New("protocol error: " + string(msg)) return errors.New("protocol error: " + string(msg))
} }
if state.fixedLen == -1 { // null bulk if state.bulkLen == -1 { // null bulk
return nil return nil
} else if state.fixedLen > 0 { } else if state.bulkLen > 0 {
state.msgType = msg[0] state.msgType = msg[0]
state.downloading = true state.readingMultiLine = true
state.expectedArgsCount = 1 state.expectedArgsCount = 1
state.receivedCount = 0 state.args = make([][]byte, 0, 1)
state.args = make([][]byte, 1)
return nil return nil
} else { } else {
return errors.New("protocol error: " + string(msg)) return errors.New("protocol error: " + string(msg))
@@ -253,23 +250,21 @@ func parseSingleLineReply(msg []byte) (redis.Reply, error) {
} }
// read the non-first lines of multi bulk reply or bulk reply // read the non-first lines of multi bulk reply or bulk reply
func readBulkBody(msg []byte, state *readState) error { func readBody(msg []byte, state *readState) error {
line := msg[0 : len(msg)-2] line := msg[0 : len(msg)-2]
var err error var err error
if line[0] == '$' { if line[0] == '$' {
// bulk reply // bulk reply
state.fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64) state.bulkLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
if err != nil { if err != nil {
return errors.New("protocol error: " + string(msg)) return errors.New("protocol error: " + string(msg))
} }
if state.fixedLen <= 0 { // null bulk in multi bulks if state.bulkLen <= 0 { // null bulk in multi bulks
state.args[state.receivedCount] = []byte{} state.args = append(state.args, []byte{})
state.receivedCount++ state.bulkLen = 0
state.fixedLen = 0
} }
} else { } else {
state.args[state.receivedCount] = line state.args = append(state.args, line)
state.receivedCount++
} }
return nil return nil
} }

View File

@@ -44,13 +44,13 @@ func (c *EchoClient) Close() error {
func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) { func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
if h.closing.Get() { if h.closing.Get() {
// closing handler refuse new connection // closing handler refuse new connection
conn.Close() _ = conn.Close()
} }
client := &EchoClient{ client := &EchoClient{
Conn: conn, Conn: conn,
} }
h.activeConn.Store(client, 1) h.activeConn.Store(client, struct{}{})
reader := bufio.NewReader(conn) reader := bufio.NewReader(conn)
for { for {
@@ -69,7 +69,7 @@ func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
//logger.Info("sleeping") //logger.Info("sleeping")
//time.Sleep(10 * time.Second) //time.Sleep(10 * time.Second)
b := []byte(msg) b := []byte(msg)
conn.Write(b) _, _ = conn.Write(b)
client.Waiting.Done() client.Waiting.Done()
} }
} }
@@ -78,10 +78,9 @@ func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
func (h *EchoHandler) Close() error { func (h *EchoHandler) Close() error {
logger.Info("handler shutting down...") logger.Info("handler shutting down...")
h.closing.Set(true) h.closing.Set(true)
// TODO: concurrent wait
h.activeConn.Range(func(key interface{}, val interface{}) bool { h.activeConn.Range(func(key interface{}, val interface{}) bool {
client := key.(*EchoClient) client := key.(*EchoClient)
client.Close() _ = client.Close()
return true return true
}) })
return nil return nil

View File

@@ -9,7 +9,6 @@ import (
"fmt" "fmt"
"github.com/hdt3213/godis/interface/tcp" "github.com/hdt3213/godis/interface/tcp"
"github.com/hdt3213/godis/lib/logger" "github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/lib/sync/atomic"
"net" "net"
"os" "os"
"os/signal" "os/signal"
@@ -50,11 +49,9 @@ func ListenAndServeWithSignal(cfg *Config, handler tcp.Handler) error {
// ListenAndServe binds port and handle requests, blocking until close // ListenAndServe binds port and handle requests, blocking until close
func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan struct{}) { func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan struct{}) {
// listen signal // listen signal
var closing atomic.Boolean
go func() { go func() {
<-closeChan <-closeChan
logger.Info("shutting down...") logger.Info("shutting down...")
closing.Set(true)
_ = listener.Close() // listener.Accept() will return err immediately _ = listener.Close() // listener.Accept() will return err immediately
_ = handler.Close() // close connections _ = handler.Close() // close connections
}() }()
@@ -70,13 +67,7 @@ func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
if closing.Get() { break
logger.Info("waiting disconnect...")
waitDone.Wait()
return // handler will be closed by defer
}
logger.Error(fmt.Sprintf("accept err: %v", err))
continue
} }
// handle // handle
logger.Info("accept link") logger.Info("accept link")
@@ -88,4 +79,5 @@ func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan
handler.Handle(ctx, conn) handler.Handle(ctx, conn)
}() }()
} }
waitDone.Wait()
} }