diff --git a/redis/connection/conn.go b/redis/connection/conn.go index f7c12fe..023face 100644 --- a/redis/connection/conn.go +++ b/redis/connection/conn.go @@ -12,7 +12,7 @@ import ( type Connection struct { conn net.Conn - // waiting util reply finished + // waiting until reply finished waitingReply wait.Wait // lock while server sending response @@ -50,7 +50,11 @@ func (c *Connection) Write(b []byte) error { return nil } c.mu.Lock() - defer c.mu.Unlock() + c.waitingReply.Add(1) + defer func() { + c.waitingReply.Done() + c.mu.Unlock() + }() _, err := c.conn.Write(b) return err diff --git a/redis/parser/parser.go b/redis/parser/parser.go index 1c166cd..6c55b5b 100644 --- a/redis/parser/parser.go +++ b/redis/parser/parser.go @@ -39,16 +39,15 @@ func ParseOne(data []byte) (redis.Reply, error) { } type readState struct { - downloading bool + readingMultiLine bool expectedArgsCount int - receivedCount int msgType byte args [][]byte - fixedLen int64 + bulkLen int64 } func (s *readState) finished() bool { - return s.expectedArgsCount > 0 && s.receivedCount == s.expectedArgsCount + return s.expectedArgsCount > 0 && len(s.args) == s.expectedArgsCount } func parse0(reader io.Reader, ch chan<- *Payload) { @@ -82,7 +81,7 @@ func parse0(reader io.Reader, ch chan<- *Payload) { } // parse line - if !state.downloading { + if !state.readingMultiLine { // receive new response if msg[0] == '*' { // multi bulk reply @@ -110,7 +109,7 @@ func parse0(reader io.Reader, ch chan<- *Payload) { state = readState{} // reset state continue } - if state.fixedLen == -1 { // null bulk reply + if state.bulkLen == -1 { // null bulk reply ch <- &Payload{ Data: &reply.NullBulkReply{}, } @@ -129,7 +128,7 @@ func parse0(reader io.Reader, ch chan<- *Payload) { } } else { // receive following bulk reply - err = readBulkBody(msg, &state) + err = readBody(msg, &state) if err != nil { ch <- &Payload{ Err: errors.New("protocol error: " + string(msg)), @@ -158,7 +157,7 @@ func parse0(reader io.Reader, ch chan<- *Payload) { func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) { var msg []byte var err error - if state.fixedLen == 0 { // read normal line + if state.bulkLen == 0 { // read normal line msg, err = bufReader.ReadBytes('\n') if err != nil { return nil, true, err @@ -167,7 +166,7 @@ func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) { return nil, false, errors.New("protocol error: " + string(msg)) } } else { // read bulk line (binary safe) - msg = make([]byte, state.fixedLen+2) + msg = make([]byte, state.bulkLen+2) _, err = io.ReadFull(bufReader, msg) if err != nil { return nil, true, err @@ -177,7 +176,7 @@ func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) { msg[len(msg)-1] != '\n' { return nil, false, errors.New("protocol error: " + string(msg)) } - state.fixedLen = 0 + state.bulkLen = 0 } return msg, false, nil } @@ -195,10 +194,9 @@ func parseMultiBulkHeader(msg []byte, state *readState) error { } else if expectedLine > 0 { // first line of multi bulk reply state.msgType = msg[0] - state.downloading = true + state.readingMultiLine = true state.expectedArgsCount = int(expectedLine) - state.receivedCount = 0 - state.args = make([][]byte, expectedLine) + state.args = make([][]byte, 0, expectedLine) return nil } else { return errors.New("protocol error: " + string(msg)) @@ -207,18 +205,17 @@ func parseMultiBulkHeader(msg []byte, state *readState) error { func parseBulkHeader(msg []byte, state *readState) error { var err error - state.fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64) + state.bulkLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64) if err != nil { return errors.New("protocol error: " + string(msg)) } - if state.fixedLen == -1 { // null bulk + if state.bulkLen == -1 { // null bulk return nil - } else if state.fixedLen > 0 { + } else if state.bulkLen > 0 { state.msgType = msg[0] - state.downloading = true + state.readingMultiLine = true state.expectedArgsCount = 1 - state.receivedCount = 0 - state.args = make([][]byte, 1) + state.args = make([][]byte, 0, 1) return nil } else { return errors.New("protocol error: " + string(msg)) @@ -253,23 +250,21 @@ func parseSingleLineReply(msg []byte) (redis.Reply, error) { } // read the non-first lines of multi bulk reply or bulk reply -func readBulkBody(msg []byte, state *readState) error { +func readBody(msg []byte, state *readState) error { line := msg[0 : len(msg)-2] var err error if line[0] == '$' { // bulk reply - state.fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64) + state.bulkLen, err = strconv.ParseInt(string(line[1:]), 10, 64) if err != nil { return errors.New("protocol error: " + string(msg)) } - if state.fixedLen <= 0 { // null bulk in multi bulks - state.args[state.receivedCount] = []byte{} - state.receivedCount++ - state.fixedLen = 0 + if state.bulkLen <= 0 { // null bulk in multi bulks + state.args = append(state.args, []byte{}) + state.bulkLen = 0 } } else { - state.args[state.receivedCount] = line - state.receivedCount++ + state.args = append(state.args, line) } return nil } diff --git a/tcp/echo.go b/tcp/echo.go index b2186a9..0314b66 100644 --- a/tcp/echo.go +++ b/tcp/echo.go @@ -44,13 +44,13 @@ func (c *EchoClient) Close() error { func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) { if h.closing.Get() { // closing handler refuse new connection - conn.Close() + _ = conn.Close() } client := &EchoClient{ Conn: conn, } - h.activeConn.Store(client, 1) + h.activeConn.Store(client, struct{}{}) reader := bufio.NewReader(conn) for { @@ -69,7 +69,7 @@ func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) { //logger.Info("sleeping") //time.Sleep(10 * time.Second) b := []byte(msg) - conn.Write(b) + _, _ = conn.Write(b) client.Waiting.Done() } } @@ -78,10 +78,9 @@ func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) { func (h *EchoHandler) Close() error { logger.Info("handler shutting down...") h.closing.Set(true) - // TODO: concurrent wait h.activeConn.Range(func(key interface{}, val interface{}) bool { client := key.(*EchoClient) - client.Close() + _ = client.Close() return true }) return nil diff --git a/tcp/server.go b/tcp/server.go index 12f1b52..9a3d6cd 100644 --- a/tcp/server.go +++ b/tcp/server.go @@ -9,7 +9,6 @@ import ( "fmt" "github.com/hdt3213/godis/interface/tcp" "github.com/hdt3213/godis/lib/logger" - "github.com/hdt3213/godis/lib/sync/atomic" "net" "os" "os/signal" @@ -50,11 +49,9 @@ func ListenAndServeWithSignal(cfg *Config, handler tcp.Handler) error { // ListenAndServe binds port and handle requests, blocking until close func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan struct{}) { // listen signal - var closing atomic.Boolean go func() { <-closeChan logger.Info("shutting down...") - closing.Set(true) _ = listener.Close() // listener.Accept() will return err immediately _ = handler.Close() // close connections }() @@ -70,13 +67,7 @@ func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan for { conn, err := listener.Accept() if err != nil { - if closing.Get() { - logger.Info("waiting disconnect...") - waitDone.Wait() - return // handler will be closed by defer - } - logger.Error(fmt.Sprintf("accept err: %v", err)) - continue + break } // handle logger.Info("accept link") @@ -88,4 +79,5 @@ func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan handler.Handle(ctx, conn) }() } + waitDone.Wait() }