This commit is contained in:
hdt3213
2021-06-05 16:43:06 +08:00
parent 083d5b325e
commit 9d03314359
4 changed files with 34 additions and 44 deletions

View File

@@ -12,7 +12,7 @@ import (
type Connection struct {
conn net.Conn
// waiting util reply finished
// waiting until reply finished
waitingReply wait.Wait
// lock while server sending response
@@ -50,7 +50,11 @@ func (c *Connection) Write(b []byte) error {
return nil
}
c.mu.Lock()
defer c.mu.Unlock()
c.waitingReply.Add(1)
defer func() {
c.waitingReply.Done()
c.mu.Unlock()
}()
_, err := c.conn.Write(b)
return err

View File

@@ -39,16 +39,15 @@ func ParseOne(data []byte) (redis.Reply, error) {
}
type readState struct {
downloading bool
readingMultiLine bool
expectedArgsCount int
receivedCount int
msgType byte
args [][]byte
fixedLen int64
bulkLen int64
}
func (s *readState) finished() bool {
return s.expectedArgsCount > 0 && s.receivedCount == s.expectedArgsCount
return s.expectedArgsCount > 0 && len(s.args) == s.expectedArgsCount
}
func parse0(reader io.Reader, ch chan<- *Payload) {
@@ -82,7 +81,7 @@ func parse0(reader io.Reader, ch chan<- *Payload) {
}
// parse line
if !state.downloading {
if !state.readingMultiLine {
// receive new response
if msg[0] == '*' {
// multi bulk reply
@@ -110,7 +109,7 @@ func parse0(reader io.Reader, ch chan<- *Payload) {
state = readState{} // reset state
continue
}
if state.fixedLen == -1 { // null bulk reply
if state.bulkLen == -1 { // null bulk reply
ch <- &Payload{
Data: &reply.NullBulkReply{},
}
@@ -129,7 +128,7 @@ func parse0(reader io.Reader, ch chan<- *Payload) {
}
} else {
// receive following bulk reply
err = readBulkBody(msg, &state)
err = readBody(msg, &state)
if err != nil {
ch <- &Payload{
Err: errors.New("protocol error: " + string(msg)),
@@ -158,7 +157,7 @@ func parse0(reader io.Reader, ch chan<- *Payload) {
func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) {
var msg []byte
var err error
if state.fixedLen == 0 { // read normal line
if state.bulkLen == 0 { // read normal line
msg, err = bufReader.ReadBytes('\n')
if err != nil {
return nil, true, err
@@ -167,7 +166,7 @@ func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) {
return nil, false, errors.New("protocol error: " + string(msg))
}
} else { // read bulk line (binary safe)
msg = make([]byte, state.fixedLen+2)
msg = make([]byte, state.bulkLen+2)
_, err = io.ReadFull(bufReader, msg)
if err != nil {
return nil, true, err
@@ -177,7 +176,7 @@ func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) {
msg[len(msg)-1] != '\n' {
return nil, false, errors.New("protocol error: " + string(msg))
}
state.fixedLen = 0
state.bulkLen = 0
}
return msg, false, nil
}
@@ -195,10 +194,9 @@ func parseMultiBulkHeader(msg []byte, state *readState) error {
} else if expectedLine > 0 {
// first line of multi bulk reply
state.msgType = msg[0]
state.downloading = true
state.readingMultiLine = true
state.expectedArgsCount = int(expectedLine)
state.receivedCount = 0
state.args = make([][]byte, expectedLine)
state.args = make([][]byte, 0, expectedLine)
return nil
} else {
return errors.New("protocol error: " + string(msg))
@@ -207,18 +205,17 @@ func parseMultiBulkHeader(msg []byte, state *readState) error {
func parseBulkHeader(msg []byte, state *readState) error {
var err error
state.fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
state.bulkLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
if err != nil {
return errors.New("protocol error: " + string(msg))
}
if state.fixedLen == -1 { // null bulk
if state.bulkLen == -1 { // null bulk
return nil
} else if state.fixedLen > 0 {
} else if state.bulkLen > 0 {
state.msgType = msg[0]
state.downloading = true
state.readingMultiLine = true
state.expectedArgsCount = 1
state.receivedCount = 0
state.args = make([][]byte, 1)
state.args = make([][]byte, 0, 1)
return nil
} else {
return errors.New("protocol error: " + string(msg))
@@ -253,23 +250,21 @@ func parseSingleLineReply(msg []byte) (redis.Reply, error) {
}
// read the non-first lines of multi bulk reply or bulk reply
func readBulkBody(msg []byte, state *readState) error {
func readBody(msg []byte, state *readState) error {
line := msg[0 : len(msg)-2]
var err error
if line[0] == '$' {
// bulk reply
state.fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
state.bulkLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
if err != nil {
return errors.New("protocol error: " + string(msg))
}
if state.fixedLen <= 0 { // null bulk in multi bulks
state.args[state.receivedCount] = []byte{}
state.receivedCount++
state.fixedLen = 0
if state.bulkLen <= 0 { // null bulk in multi bulks
state.args = append(state.args, []byte{})
state.bulkLen = 0
}
} else {
state.args[state.receivedCount] = line
state.receivedCount++
state.args = append(state.args, line)
}
return nil
}

View File

@@ -44,13 +44,13 @@ func (c *EchoClient) Close() error {
func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
if h.closing.Get() {
// closing handler refuse new connection
conn.Close()
_ = conn.Close()
}
client := &EchoClient{
Conn: conn,
}
h.activeConn.Store(client, 1)
h.activeConn.Store(client, struct{}{})
reader := bufio.NewReader(conn)
for {
@@ -69,7 +69,7 @@ func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
//logger.Info("sleeping")
//time.Sleep(10 * time.Second)
b := []byte(msg)
conn.Write(b)
_, _ = conn.Write(b)
client.Waiting.Done()
}
}
@@ -78,10 +78,9 @@ func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
func (h *EchoHandler) Close() error {
logger.Info("handler shutting down...")
h.closing.Set(true)
// TODO: concurrent wait
h.activeConn.Range(func(key interface{}, val interface{}) bool {
client := key.(*EchoClient)
client.Close()
_ = client.Close()
return true
})
return nil

View File

@@ -9,7 +9,6 @@ import (
"fmt"
"github.com/hdt3213/godis/interface/tcp"
"github.com/hdt3213/godis/lib/logger"
"github.com/hdt3213/godis/lib/sync/atomic"
"net"
"os"
"os/signal"
@@ -50,11 +49,9 @@ func ListenAndServeWithSignal(cfg *Config, handler tcp.Handler) error {
// ListenAndServe binds port and handle requests, blocking until close
func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan struct{}) {
// listen signal
var closing atomic.Boolean
go func() {
<-closeChan
logger.Info("shutting down...")
closing.Set(true)
_ = listener.Close() // listener.Accept() will return err immediately
_ = handler.Close() // close connections
}()
@@ -70,13 +67,7 @@ func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan
for {
conn, err := listener.Accept()
if err != nil {
if closing.Get() {
logger.Info("waiting disconnect...")
waitDone.Wait()
return // handler will be closed by defer
}
logger.Error(fmt.Sprintf("accept err: %v", err))
continue
break
}
// handle
logger.Info("accept link")
@@ -88,4 +79,5 @@ func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan
handler.Handle(ctx, conn)
}()
}
waitDone.Wait()
}