mirror of
https://github.com/HDT3213/godis.git
synced 2025-10-05 08:46:56 +08:00
tiny fix
This commit is contained in:
@@ -12,7 +12,7 @@ import (
|
|||||||
type Connection struct {
|
type Connection struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
|
|
||||||
// waiting util reply finished
|
// waiting until reply finished
|
||||||
waitingReply wait.Wait
|
waitingReply wait.Wait
|
||||||
|
|
||||||
// lock while server sending response
|
// lock while server sending response
|
||||||
@@ -50,7 +50,11 @@ func (c *Connection) Write(b []byte) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
c.waitingReply.Add(1)
|
||||||
|
defer func() {
|
||||||
|
c.waitingReply.Done()
|
||||||
|
c.mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
_, err := c.conn.Write(b)
|
_, err := c.conn.Write(b)
|
||||||
return err
|
return err
|
||||||
|
@@ -39,16 +39,15 @@ func ParseOne(data []byte) (redis.Reply, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type readState struct {
|
type readState struct {
|
||||||
downloading bool
|
readingMultiLine bool
|
||||||
expectedArgsCount int
|
expectedArgsCount int
|
||||||
receivedCount int
|
|
||||||
msgType byte
|
msgType byte
|
||||||
args [][]byte
|
args [][]byte
|
||||||
fixedLen int64
|
bulkLen int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *readState) finished() bool {
|
func (s *readState) finished() bool {
|
||||||
return s.expectedArgsCount > 0 && s.receivedCount == s.expectedArgsCount
|
return s.expectedArgsCount > 0 && len(s.args) == s.expectedArgsCount
|
||||||
}
|
}
|
||||||
|
|
||||||
func parse0(reader io.Reader, ch chan<- *Payload) {
|
func parse0(reader io.Reader, ch chan<- *Payload) {
|
||||||
@@ -82,7 +81,7 @@ func parse0(reader io.Reader, ch chan<- *Payload) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// parse line
|
// parse line
|
||||||
if !state.downloading {
|
if !state.readingMultiLine {
|
||||||
// receive new response
|
// receive new response
|
||||||
if msg[0] == '*' {
|
if msg[0] == '*' {
|
||||||
// multi bulk reply
|
// multi bulk reply
|
||||||
@@ -110,7 +109,7 @@ func parse0(reader io.Reader, ch chan<- *Payload) {
|
|||||||
state = readState{} // reset state
|
state = readState{} // reset state
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if state.fixedLen == -1 { // null bulk reply
|
if state.bulkLen == -1 { // null bulk reply
|
||||||
ch <- &Payload{
|
ch <- &Payload{
|
||||||
Data: &reply.NullBulkReply{},
|
Data: &reply.NullBulkReply{},
|
||||||
}
|
}
|
||||||
@@ -129,7 +128,7 @@ func parse0(reader io.Reader, ch chan<- *Payload) {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// receive following bulk reply
|
// receive following bulk reply
|
||||||
err = readBulkBody(msg, &state)
|
err = readBody(msg, &state)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ch <- &Payload{
|
ch <- &Payload{
|
||||||
Err: errors.New("protocol error: " + string(msg)),
|
Err: errors.New("protocol error: " + string(msg)),
|
||||||
@@ -158,7 +157,7 @@ func parse0(reader io.Reader, ch chan<- *Payload) {
|
|||||||
func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) {
|
func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) {
|
||||||
var msg []byte
|
var msg []byte
|
||||||
var err error
|
var err error
|
||||||
if state.fixedLen == 0 { // read normal line
|
if state.bulkLen == 0 { // read normal line
|
||||||
msg, err = bufReader.ReadBytes('\n')
|
msg, err = bufReader.ReadBytes('\n')
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, true, err
|
return nil, true, err
|
||||||
@@ -167,7 +166,7 @@ func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) {
|
|||||||
return nil, false, errors.New("protocol error: " + string(msg))
|
return nil, false, errors.New("protocol error: " + string(msg))
|
||||||
}
|
}
|
||||||
} else { // read bulk line (binary safe)
|
} else { // read bulk line (binary safe)
|
||||||
msg = make([]byte, state.fixedLen+2)
|
msg = make([]byte, state.bulkLen+2)
|
||||||
_, err = io.ReadFull(bufReader, msg)
|
_, err = io.ReadFull(bufReader, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, true, err
|
return nil, true, err
|
||||||
@@ -177,7 +176,7 @@ func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) {
|
|||||||
msg[len(msg)-1] != '\n' {
|
msg[len(msg)-1] != '\n' {
|
||||||
return nil, false, errors.New("protocol error: " + string(msg))
|
return nil, false, errors.New("protocol error: " + string(msg))
|
||||||
}
|
}
|
||||||
state.fixedLen = 0
|
state.bulkLen = 0
|
||||||
}
|
}
|
||||||
return msg, false, nil
|
return msg, false, nil
|
||||||
}
|
}
|
||||||
@@ -195,10 +194,9 @@ func parseMultiBulkHeader(msg []byte, state *readState) error {
|
|||||||
} else if expectedLine > 0 {
|
} else if expectedLine > 0 {
|
||||||
// first line of multi bulk reply
|
// first line of multi bulk reply
|
||||||
state.msgType = msg[0]
|
state.msgType = msg[0]
|
||||||
state.downloading = true
|
state.readingMultiLine = true
|
||||||
state.expectedArgsCount = int(expectedLine)
|
state.expectedArgsCount = int(expectedLine)
|
||||||
state.receivedCount = 0
|
state.args = make([][]byte, 0, expectedLine)
|
||||||
state.args = make([][]byte, expectedLine)
|
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
return errors.New("protocol error: " + string(msg))
|
return errors.New("protocol error: " + string(msg))
|
||||||
@@ -207,18 +205,17 @@ func parseMultiBulkHeader(msg []byte, state *readState) error {
|
|||||||
|
|
||||||
func parseBulkHeader(msg []byte, state *readState) error {
|
func parseBulkHeader(msg []byte, state *readState) error {
|
||||||
var err error
|
var err error
|
||||||
state.fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
|
state.bulkLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("protocol error: " + string(msg))
|
return errors.New("protocol error: " + string(msg))
|
||||||
}
|
}
|
||||||
if state.fixedLen == -1 { // null bulk
|
if state.bulkLen == -1 { // null bulk
|
||||||
return nil
|
return nil
|
||||||
} else if state.fixedLen > 0 {
|
} else if state.bulkLen > 0 {
|
||||||
state.msgType = msg[0]
|
state.msgType = msg[0]
|
||||||
state.downloading = true
|
state.readingMultiLine = true
|
||||||
state.expectedArgsCount = 1
|
state.expectedArgsCount = 1
|
||||||
state.receivedCount = 0
|
state.args = make([][]byte, 0, 1)
|
||||||
state.args = make([][]byte, 1)
|
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
return errors.New("protocol error: " + string(msg))
|
return errors.New("protocol error: " + string(msg))
|
||||||
@@ -253,23 +250,21 @@ func parseSingleLineReply(msg []byte) (redis.Reply, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// read the non-first lines of multi bulk reply or bulk reply
|
// read the non-first lines of multi bulk reply or bulk reply
|
||||||
func readBulkBody(msg []byte, state *readState) error {
|
func readBody(msg []byte, state *readState) error {
|
||||||
line := msg[0 : len(msg)-2]
|
line := msg[0 : len(msg)-2]
|
||||||
var err error
|
var err error
|
||||||
if line[0] == '$' {
|
if line[0] == '$' {
|
||||||
// bulk reply
|
// bulk reply
|
||||||
state.fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
|
state.bulkLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("protocol error: " + string(msg))
|
return errors.New("protocol error: " + string(msg))
|
||||||
}
|
}
|
||||||
if state.fixedLen <= 0 { // null bulk in multi bulks
|
if state.bulkLen <= 0 { // null bulk in multi bulks
|
||||||
state.args[state.receivedCount] = []byte{}
|
state.args = append(state.args, []byte{})
|
||||||
state.receivedCount++
|
state.bulkLen = 0
|
||||||
state.fixedLen = 0
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
state.args[state.receivedCount] = line
|
state.args = append(state.args, line)
|
||||||
state.receivedCount++
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@@ -44,13 +44,13 @@ func (c *EchoClient) Close() error {
|
|||||||
func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
|
func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
|
||||||
if h.closing.Get() {
|
if h.closing.Get() {
|
||||||
// closing handler refuse new connection
|
// closing handler refuse new connection
|
||||||
conn.Close()
|
_ = conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
client := &EchoClient{
|
client := &EchoClient{
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
}
|
}
|
||||||
h.activeConn.Store(client, 1)
|
h.activeConn.Store(client, struct{}{})
|
||||||
|
|
||||||
reader := bufio.NewReader(conn)
|
reader := bufio.NewReader(conn)
|
||||||
for {
|
for {
|
||||||
@@ -69,7 +69,7 @@ func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
|
|||||||
//logger.Info("sleeping")
|
//logger.Info("sleeping")
|
||||||
//time.Sleep(10 * time.Second)
|
//time.Sleep(10 * time.Second)
|
||||||
b := []byte(msg)
|
b := []byte(msg)
|
||||||
conn.Write(b)
|
_, _ = conn.Write(b)
|
||||||
client.Waiting.Done()
|
client.Waiting.Done()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -78,10 +78,9 @@ func (h *EchoHandler) Handle(ctx context.Context, conn net.Conn) {
|
|||||||
func (h *EchoHandler) Close() error {
|
func (h *EchoHandler) Close() error {
|
||||||
logger.Info("handler shutting down...")
|
logger.Info("handler shutting down...")
|
||||||
h.closing.Set(true)
|
h.closing.Set(true)
|
||||||
// TODO: concurrent wait
|
|
||||||
h.activeConn.Range(func(key interface{}, val interface{}) bool {
|
h.activeConn.Range(func(key interface{}, val interface{}) bool {
|
||||||
client := key.(*EchoClient)
|
client := key.(*EchoClient)
|
||||||
client.Close()
|
_ = client.Close()
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
|
@@ -9,7 +9,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"github.com/hdt3213/godis/interface/tcp"
|
"github.com/hdt3213/godis/interface/tcp"
|
||||||
"github.com/hdt3213/godis/lib/logger"
|
"github.com/hdt3213/godis/lib/logger"
|
||||||
"github.com/hdt3213/godis/lib/sync/atomic"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
@@ -50,11 +49,9 @@ func ListenAndServeWithSignal(cfg *Config, handler tcp.Handler) error {
|
|||||||
// ListenAndServe binds port and handle requests, blocking until close
|
// ListenAndServe binds port and handle requests, blocking until close
|
||||||
func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan struct{}) {
|
func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan struct{}) {
|
||||||
// listen signal
|
// listen signal
|
||||||
var closing atomic.Boolean
|
|
||||||
go func() {
|
go func() {
|
||||||
<-closeChan
|
<-closeChan
|
||||||
logger.Info("shutting down...")
|
logger.Info("shutting down...")
|
||||||
closing.Set(true)
|
|
||||||
_ = listener.Close() // listener.Accept() will return err immediately
|
_ = listener.Close() // listener.Accept() will return err immediately
|
||||||
_ = handler.Close() // close connections
|
_ = handler.Close() // close connections
|
||||||
}()
|
}()
|
||||||
@@ -70,13 +67,7 @@ func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan
|
|||||||
for {
|
for {
|
||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if closing.Get() {
|
break
|
||||||
logger.Info("waiting disconnect...")
|
|
||||||
waitDone.Wait()
|
|
||||||
return // handler will be closed by defer
|
|
||||||
}
|
|
||||||
logger.Error(fmt.Sprintf("accept err: %v", err))
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
// handle
|
// handle
|
||||||
logger.Info("accept link")
|
logger.Info("accept link")
|
||||||
@@ -88,4 +79,5 @@ func ListenAndServe(listener net.Listener, handler tcp.Handler, closeChan <-chan
|
|||||||
handler.Handle(ctx, conn)
|
handler.Handle(ctx, conn)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
waitDone.Wait()
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user