refactor RESP parser

This commit is contained in:
hdt3213
2021-04-18 22:33:19 +08:00
parent 1eb55cb5fb
commit 3fa3d32e1e
10 changed files with 474 additions and 414 deletions

View File

@@ -1,7 +1,6 @@
package db
import (
"bufio"
"github.com/hdt3213/godis/src/config"
"github.com/hdt3213/godis/src/datastruct/dict"
List "github.com/hdt3213/godis/src/datastruct/list"
@@ -9,6 +8,8 @@ import (
"github.com/hdt3213/godis/src/datastruct/set"
SortedSet "github.com/hdt3213/godis/src/datastruct/sortedset"
"github.com/hdt3213/godis/src/lib/logger"
"github.com/hdt3213/godis/src/lib/utils"
"github.com/hdt3213/godis/src/redis/parser"
"github.com/hdt3213/godis/src/redis/reply"
"io"
"io/ioutil"
@@ -72,6 +73,7 @@ func trim(msg []byte) string {
return trimmed
}
// read aof file
func (db *DB) loadAof(maxBytes int) {
// delete aofChan to prevent write again
@@ -91,96 +93,29 @@ func (db *DB) loadAof(maxBytes int) {
}
defer file.Close()
reader := bufio.NewReader(file)
var fixedLen int64 = 0
var expectedArgsCount uint32
var receivedCount uint32
var args [][]byte
processing := false
var msg []byte
readBytes := 0
for {
if maxBytes != 0 && readBytes >= maxBytes {
break
reader := utils.NewLimitedReader(file, maxBytes)
ch := parser.Parse(reader)
for p := range ch {
if p.Err != nil {
if p.Err == io.EOF {
break
}
logger.Error("parse error: " + p.Err.Error())
continue
}
if fixedLen == 0 {
msg, err = reader.ReadBytes('\n')
if err == io.EOF {
return
}
if len(msg) == 0 {
logger.Warn("invalid format: line should end with \\r\\n")
return
}
readBytes += len(msg)
} else {
msg = make([]byte, fixedLen+2)
n, err := io.ReadFull(reader, msg)
if err == io.EOF {
return
}
if len(msg) == 0 {
logger.Warn("invalid multibulk length")
return
}
fixedLen = 0
readBytes += n
if p.Data == nil {
logger.Error("empty payload")
continue
}
if err != nil {
logger.Warn(err)
return
r, ok := p.Data.(*reply.MultiBulkReply)
if !ok {
logger.Error("require multi bulk reply")
continue
}
if !processing {
// new request
if msg[0] == '*' {
// bulk multi msg
expectedLine, err := strconv.ParseUint(trim(msg[1:]), 10, 32)
if err != nil {
logger.Warn(err)
return
}
expectedArgsCount = uint32(expectedLine)
receivedCount = 0
processing = true
args = make([][]byte, expectedLine)
} else {
logger.Warn("msg should start with '*'")
return
}
} else {
// receive following part of a request
line := msg[0 : len(msg)-2]
if line[0] == '$' {
fixedLen, err = strconv.ParseInt(trim(line[1:]), 10, 64)
if err != nil {
logger.Warn(err)
return
}
if fixedLen <= 0 {
logger.Warn("invalid multibulk length")
return
}
} else {
args[receivedCount] = line
receivedCount++
}
// if sending finished
if receivedCount == expectedArgsCount {
processing = false
cmd := strings.ToLower(string(args[0]))
cmdFunc, ok := router[cmd]
if ok {
cmdFunc(db, args[1:])
}
// finish
expectedArgsCount = 0
receivedCount = 0
args = nil
}
cmd := strings.ToLower(string(r.Args[0]))
cmdFunc, ok := router[cmd]
if ok {
cmdFunc(db, r.Args[1:])
}
}
}

View File

@@ -37,6 +37,12 @@ const (
FATAL
)
const flags = log.LstdFlags
func init() {
logger = log.New(os.Stdout, DefaultPrefix, flags)
}
func Setup(settings *Settings) {
var err error
dir := settings.Path
@@ -51,7 +57,7 @@ func Setup(settings *Settings) {
}
mw := io.MultiWriter(os.Stdout, logFile)
logger = log.New(mw, DefaultPrefix, log.LstdFlags)
logger = log.New(mw, DefaultPrefix, flags)
}
func setPrefix(level Level) {

View File

@@ -1,21 +0,0 @@
package gob
import (
"bytes"
"encoding/gob"
)
func Marshal(obj interface{}) ([]byte, error) {
buf := new(bytes.Buffer)
enc := gob.NewEncoder(buf)
err := enc.Encode(obj)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func UnMarshal(data []byte, obj interface{}) error {
dec := gob.NewDecoder(bytes.NewBuffer(data))
return dec.Decode(obj)
}

View File

@@ -0,0 +1,34 @@
package utils
import (
"errors"
"io"
)
type LimitedReader struct {
src io.Reader
n int
limit int
}
func NewLimitedReader(src io.Reader, limit int) *LimitedReader {
return &LimitedReader{
src: src,
limit: limit,
}
}
func (r *LimitedReader) Read(p []byte) (n int, err error) {
if r.src == nil {
return 0, errors.New("no data source")
}
if r.limit > 0 && r.n >= r.limit {
return 0, io.EOF
}
n, err = r.src.Read(p)
if err != nil {
return n, err
}
r.n += n
return
}

View File

@@ -1,31 +1,25 @@
package client
import (
"bufio"
"context"
"errors"
"github.com/hdt3213/godis/src/interface/redis"
"github.com/hdt3213/godis/src/lib/logger"
"github.com/hdt3213/godis/src/lib/sync/wait"
"github.com/hdt3213/godis/src/redis/parser"
"github.com/hdt3213/godis/src/redis/reply"
"io"
"net"
"strconv"
"strings"
"runtime/debug"
"sync"
"time"
)
type Client struct {
conn net.Conn
sendingReqs chan *Request // waiting sending
pendingReqs chan *Request // wait to send
waitingReqs chan *Request // waiting response
ticker *time.Ticker
addr string
ctx context.Context
cancelFunc context.CancelFunc
writing *sync.WaitGroup
working *sync.WaitGroup // its counter presents unfinished requests(pending and waiting)
}
type Request struct {
@@ -47,15 +41,12 @@ func MakeClient(addr string) (*Client, error) {
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(context.Background())
return &Client{
addr: addr,
conn: conn,
sendingReqs: make(chan *Request, chanSize),
pendingReqs: make(chan *Request, chanSize),
waitingReqs: make(chan *Request, chanSize),
ctx: ctx,
cancelFunc: cancel,
writing: &sync.WaitGroup{},
working: &sync.WaitGroup{},
}, nil
}
@@ -64,20 +55,22 @@ func (client *Client) Start() {
go client.handleWrite()
go func() {
err := client.handleRead()
logger.Warn(err)
if err != nil {
logger.Error(err)
}
}()
go client.heartbeat()
}
func (client *Client) Close() {
client.ticker.Stop()
// stop new request
close(client.sendingReqs)
close(client.pendingReqs)
// wait stop process
client.writing.Wait()
client.working.Wait()
// clean
client.cancelFunc()
_ = client.conn.Close()
close(client.waitingReqs)
}
@@ -106,34 +99,17 @@ func (client *Client) handleConnectionError(err error) error {
}
func (client *Client) heartbeat() {
loop:
for {
select {
case <-client.ticker.C:
client.sendingReqs <- &Request{
args: [][]byte{[]byte("PING")},
heartbeat: true,
}
case <-client.ctx.Done():
break loop
}
for range client.ticker.C {
client.doHeartbeat()
}
}
func (client *Client) handleWrite() {
loop:
for {
select {
case req := <-client.sendingReqs:
client.writing.Add(1)
client.doRequest(req)
case <-client.ctx.Done():
break loop
}
for req := range client.pendingReqs {
client.doRequest(req)
}
}
// todo: wait with timeout
func (client *Client) Send(args [][]byte) redis.Reply {
request := &Request{
args: args,
@@ -141,7 +117,9 @@ func (client *Client) Send(args [][]byte) redis.Reply {
waiting: &wait.Wait{},
}
request.waiting.Add(1)
client.sendingReqs <- request
client.working.Add(1)
defer client.working.Done()
client.pendingReqs <- request
timeout := request.waiting.WaitWithTimeout(maxWait)
if timeout {
return reply.MakeErrReply("server time out")
@@ -152,8 +130,24 @@ func (client *Client) Send(args [][]byte) redis.Reply {
return request.reply
}
func (client *Client) doHeartbeat() {
request := &Request{
args: [][]byte{[]byte("PING")},
heartbeat: true,
}
request.waiting.Add(1)
client.working.Add(1)
defer client.working.Done()
client.pendingReqs <- request
request.waiting.WaitWithTimeout(maxWait)
}
func (client *Client) doRequest(req *Request) {
bytes := reply.MakeMultiBulkReply(req.args).ToBytes()
if req == nil || len(req.args) == 0 {
return
}
re := reply.MakeMultiBulkReply(req.args)
bytes := re.ToBytes()
_, err := client.conn.Write(bytes)
i := 0
for err != nil && i < 3 {
@@ -168,154 +162,34 @@ func (client *Client) doRequest(req *Request) {
} else {
req.err = err
req.waiting.Done()
client.writing.Done()
}
}
func (client *Client) finishRequest(reply redis.Reply) {
defer func() {
if err := recover(); err != nil {
debug.PrintStack()
logger.Error(err)
}
}()
request := <-client.waitingReqs
if request == nil {
return
}
request.reply = reply
if request.waiting != nil {
request.waiting.Done()
}
client.writing.Done()
}
func (client *Client) handleRead() error {
reader := bufio.NewReader(client.conn)
downloading := false
expectedArgsCount := 0
receivedCount := 0
msgType := byte(0) // first char of msg
var args [][]byte
var fixedLen int64 = 0
var err error
var msg []byte
for {
// read line
if fixedLen == 0 { // read normal line
msg, err = reader.ReadBytes('\n')
if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
logger.Info("connection close")
} else {
logger.Warn(err)
}
return errors.New("connection closed")
}
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
return errors.New("protocol error")
}
} else { // read bulk line (binary safe)
msg = make([]byte, fixedLen+2)
_, err = io.ReadFull(reader, msg)
if err != nil {
if err == io.EOF || err == io.ErrUnexpectedEOF {
return errors.New("connection closed")
} else {
return err
}
}
if len(msg) == 0 ||
msg[len(msg)-2] != '\r' ||
msg[len(msg)-1] != '\n' {
return errors.New("protocol error")
}
fixedLen = 0
}
// parse line
if !downloading {
// receive new response
if msg[0] == '*' { // multi bulk response
// bulk multi msg
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
if err != nil {
return errors.New("protocol error: " + err.Error())
}
if expectedLine == 0 {
client.finishRequest(&reply.EmptyMultiBulkReply{})
} else if expectedLine > 0 {
msgType = msg[0]
downloading = true
expectedArgsCount = int(expectedLine)
receivedCount = 0
args = make([][]byte, expectedLine)
} else {
return errors.New("protocol error")
}
} else if msg[0] == '$' { // bulk response
fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
if err != nil {
return err
}
if fixedLen == -1 { // null bulk
client.finishRequest(&reply.NullBulkReply{})
fixedLen = 0
} else if fixedLen > 0 {
msgType = msg[0]
downloading = true
expectedArgsCount = 1
receivedCount = 0
args = make([][]byte, 1)
} else {
return errors.New("protocol error")
}
} else { // single line response
str := strings.TrimSuffix(string(msg), "\n")
str = strings.TrimSuffix(str, "\r")
var result redis.Reply
switch msg[0] {
case '+':
result = reply.MakeStatusReply(str[1:])
case '-':
result = reply.MakeErrReply(str[1:])
case ':':
val, err := strconv.ParseInt(str[1:], 10, 64)
if err != nil {
return errors.New("protocol error")
}
result = reply.MakeIntReply(val)
}
client.finishRequest(result)
}
} else {
// receive following part of a request
line := msg[0 : len(msg)-2]
if line[0] == '$' {
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
if err != nil {
return err
}
if fixedLen <= 0 { // null bulk in multi bulks
args[receivedCount] = []byte{}
receivedCount++
fixedLen = 0
}
} else {
args[receivedCount] = line
receivedCount++
}
// if sending finished
if receivedCount == expectedArgsCount {
downloading = false // finish downloading progress
if msgType == '*' {
reply := reply.MakeMultiBulkReply(args)
client.finishRequest(reply)
} else if msgType == '$' {
reply := reply.MakeBulkReply(args[0])
client.finishRequest(reply)
}
// finish reply
expectedArgsCount = 0
receivedCount = 0
args = nil
msgType = byte(0)
}
ch := parser.Parse(client.conn)
for payload := range ch {
if payload.Err != nil {
client.finishRequest(reply.MakeErrReply(payload.Err.Error()))
continue
}
client.finishRequest(payload.Data)
}
return nil
}

262
src/redis/parser/parser.go Normal file
View File

@@ -0,0 +1,262 @@
package parser
import (
"bufio"
"errors"
"github.com/hdt3213/godis/src/interface/redis"
"github.com/hdt3213/godis/src/lib/logger"
"github.com/hdt3213/godis/src/redis/reply"
"io"
"runtime/debug"
"strconv"
"strings"
)
type Payload struct {
Data redis.Reply
Err error
}
func Parse(reader io.Reader) <-chan *Payload {
ch := make(chan *Payload)
go func() {
defer func() {
if err := recover(); err != nil {
logger.Error(debug.Stack())
}
}()
parse0(reader, ch)
}()
return ch
}
type readState struct {
downloading bool
expectedArgsCount int
receivedCount int
msgType byte
args [][]byte
fixedLen int64
}
func (s *readState) finished() bool {
return s.expectedArgsCount > 0 && s.receivedCount == s.expectedArgsCount
}
func parse0(reader io.Reader, ch chan<- *Payload) {
bufReader := bufio.NewReader(reader)
var state readState
var err error
var msg []byte
for {
// read line
var ioErr bool
msg, ioErr, err = readLine(bufReader, &state)
if err != nil {
if ioErr { // encounter io err, stop read
ch <- &Payload{
Err: err,
}
close(ch)
return
} else { // protocol err, reset read state
ch <- &Payload{
Err: err,
}
state = readState{}
continue
}
}
// parse line
if !state.downloading {
// receive new response
if msg[0] == '*' {
// multi bulk reply
err = parseMultiBulkHeader(msg, &state)
if err != nil {
ch <- &Payload{
Err: errors.New("protocol error: " + string(msg)),
}
state = readState{} // reset state
continue
}
if state.expectedArgsCount == 0 {
ch <- &Payload{
Data: &reply.EmptyMultiBulkReply{},
}
state = readState{} // reset state
continue
}
} else if msg[0] == '$' { // bulk reply
err = parseBulkHeader(msg, &state)
if err != nil {
ch <- &Payload{
Err: errors.New("protocol error: " + string(msg)),
}
state = readState{} // reset state
continue
}
if state.fixedLen == -1 { // null bulk reply
ch <- &Payload{
Data: &reply.NullBulkReply{},
}
state = readState{} // reset state
continue
}
} else {
// single line reply
result, err := parseSingleLineReply(msg)
ch <- &Payload{
Data: result,
Err: err,
}
state = readState{} // reset state
continue
}
} else {
// receive following bulk reply
err = readBulkBody(msg, &state)
if err != nil {
ch <- &Payload{
Err: errors.New("protocol error: " + string(msg)),
}
state = readState{} // reset state
continue
}
// if sending finished
if state.finished() {
var result redis.Reply
if state.msgType == '*' {
result = reply.MakeMultiBulkReply(state.args)
} else if state.msgType == '$' {
result = reply.MakeBulkReply(state.args[0])
}
ch <- &Payload{
Data: result,
Err: err,
}
state = readState{}
}
}
}
}
func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) {
var msg []byte
var err error
if state.fixedLen == 0 { // read normal line
msg, err = bufReader.ReadBytes('\n')
if err != nil {
return nil, true, err
}
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
return nil, false, errors.New("protocol error: " + string(msg))
}
} else { // read bulk line (binary safe)
msg = make([]byte, state.fixedLen+2)
_, err = io.ReadFull(bufReader, msg)
if err != nil {
return nil, true, err
}
if len(msg) == 0 ||
msg[len(msg)-2] != '\r' ||
msg[len(msg)-1] != '\n' {
return nil, false, errors.New("protocol error: " + string(msg))
}
state.fixedLen = 0
}
return msg, false, nil
}
func parseMultiBulkHeader(msg []byte, state *readState) error {
var err error
var expectedLine uint64
expectedLine, err = strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
if err != nil {
return errors.New("protocol error: " + string(msg))
}
if expectedLine == 0 {
state.expectedArgsCount = 0
return nil
} else if expectedLine > 0 {
// first line of multi bulk reply
state.msgType = msg[0]
state.downloading = true
state.expectedArgsCount = int(expectedLine)
state.receivedCount = 0
state.args = make([][]byte, expectedLine)
return nil
} else {
return errors.New("protocol error: " + string(msg))
}
}
func parseBulkHeader(msg []byte, state *readState) error {
var err error
state.fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
if err != nil {
return errors.New("protocol error: " + string(msg))
}
if state.fixedLen == -1 { // null bulk
return nil
} else if state.fixedLen > 0 {
state.msgType = msg[0]
state.downloading = true
state.expectedArgsCount = 1
state.receivedCount = 0
state.args = make([][]byte, 1)
return nil
} else {
return errors.New("protocol error: " + string(msg))
}
}
func parseSingleLineReply(msg []byte) (redis.Reply, error) {
str := strings.TrimSuffix(string(msg), "\n")
str = strings.TrimSuffix(str, "\r")
var result redis.Reply
switch msg[0] {
case '+': // status reply
result = reply.MakeStatusReply(str[1:])
case '-': // err reply
result = reply.MakeErrReply(str[1:])
case ':': // int reply
val, err := strconv.ParseInt(str[1:], 10, 64)
if err != nil {
return nil, errors.New("protocol error: " + string(msg))
}
result = reply.MakeIntReply(val)
default:
// parse as text protocol
strs := strings.Split(str, " ")
args := make([][]byte, len(strs))
for i, s := range strs {
args[i] = []byte(s)
}
result = reply.MakeMultiBulkReply(args)
}
return result, nil
}
// read the non-first lines of multi bulk reply or bulk reply
func readBulkBody(msg []byte, state *readState) error {
line := msg[0 : len(msg)-2]
var err error
if line[0] == '$' {
// bulk reply
state.fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
if err != nil {
return errors.New("protocol error: " + string(msg))
}
if state.fixedLen <= 0 { // null bulk in multi bulks
state.args[state.receivedCount] = []byte{}
state.receivedCount++
state.fixedLen = 0
}
} else {
state.args[state.receivedCount] = line
state.receivedCount++
}
return nil
}

View File

@@ -0,0 +1,56 @@
package parser
import (
"bytes"
"github.com/hdt3213/godis/src/datastruct/utils"
"github.com/hdt3213/godis/src/interface/redis"
"github.com/hdt3213/godis/src/redis/reply"
"io"
"testing"
)
func TestParse(t *testing.T) {
replies := []redis.Reply{
reply.MakeIntReply(1),
reply.MakeStatusReply("OK"),
reply.MakeErrReply("ERR unknown"),
reply.MakeBulkReply([]byte("a\r\nb")), // test binary safe
reply.MakeNullBulkReply(),
reply.MakeMultiBulkReply([][]byte{
[]byte("a"),
[]byte("\r\n"),
}),
reply.MakeEmptyMultiBulkReply(),
}
reqs := bytes.Buffer{}
for _, re := range replies {
reqs.Write(re.ToBytes())
}
reqs.Write([]byte("set a a" + reply.CRLF)) // test text protocol
expected := make([]redis.Reply, len(replies))
copy(expected, replies)
expected = append(expected, reply.MakeMultiBulkReply([][]byte{
[]byte("set"), []byte("a"), []byte("a"),
}))
ch := Parse(bytes.NewReader(reqs.Bytes()))
i := 0
for payload := range ch {
if payload.Err != nil {
if payload.Err == io.EOF {
return
}
t.Error(payload.Err)
return
}
if payload.Data == nil {
t.Error("empty data")
return
}
exp := expected[i]
i++
if !utils.BytesEquals(exp.ToBytes(), payload.Data.ToBytes()) {
t.Error("parse failed: " + string(exp.ToBytes()))
}
}
}

View File

@@ -24,6 +24,10 @@ func (r *NullBulkReply) ToBytes() []byte {
return nullBulkBytes
}
func MakeNullBulkReply() *NullBulkReply {
return &NullBulkReply{}
}
var emptyMultiBulkBytes = []byte("*0\r\n")
type EmptyMultiBulkReply struct{}
@@ -32,6 +36,10 @@ func (r *EmptyMultiBulkReply) ToBytes() []byte {
return emptyMultiBulkBytes
}
func MakeEmptyMultiBulkReply() *EmptyMultiBulkReply {
return &EmptyMultiBulkReply{}
}
// reply nothing, for commands like subscribe
type NoReply struct{}

View File

@@ -1,7 +1,6 @@
package server
import (
"github.com/hdt3213/godis/src/lib/sync/atomic"
"github.com/hdt3213/godis/src/lib/sync/wait"
"net"
"sync"
@@ -15,15 +14,6 @@ type Client struct {
// waiting util reply finished
waitingReply wait.Wait
// is sending request in progress
uploading atomic.AtomicBool
// multi bulk msg lineCount - 1(first line)
expectedArgsCount uint32
// sent line count, exclude first line
receivedCount uint32
// sent lines, exclude first line
args [][]byte
// lock while server sending response
mu sync.Mutex

View File

@@ -5,7 +5,6 @@ package server
*/
import (
"bufio"
"context"
"github.com/hdt3213/godis/src/cluster"
"github.com/hdt3213/godis/src/config"
@@ -13,10 +12,10 @@ import (
"github.com/hdt3213/godis/src/interface/db"
"github.com/hdt3213/godis/src/lib/logger"
"github.com/hdt3213/godis/src/lib/sync/atomic"
"github.com/hdt3213/godis/src/redis/parser"
"github.com/hdt3213/godis/src/redis/reply"
"io"
"net"
"strconv"
"strings"
"sync"
)
@@ -59,126 +58,43 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
client := MakeClient(conn)
h.activeConn.Store(client, 1)
reader := bufio.NewReader(conn)
var fixedLen int64 = 0
var err error
var msg []byte
for {
if fixedLen == 0 {
msg, err = reader.ReadBytes('\n')
if err != nil {
if err == io.EOF ||
err == io.ErrUnexpectedEOF ||
strings.Contains(err.Error(), "use of closed network connection") {
logger.Info("connection close")
} else {
logger.Warn(err)
}
// after client close
ch := parser.Parse(conn)
for payload := range ch {
if payload.Err != nil {
if payload.Err == io.EOF ||
payload.Err == io.ErrUnexpectedEOF ||
strings.Contains(payload.Err.Error(), "use of closed network connection") {
// connection closed
h.closeClient(client)
return // io error, disconnect with client
}
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
_, _ = client.conn.Write(errReply.ToBytes())
}
} else {
msg = make([]byte, fixedLen+2)
_, err = io.ReadFull(reader, msg)
if err != nil {
if err == io.EOF ||
err == io.ErrUnexpectedEOF ||
strings.Contains(err.Error(), "use of closed network connection") {
logger.Info("connection close")
} else {
logger.Warn(err)
}
// after client close
h.closeClient(client)
return // io error, disconnect with client
}
if len(msg) == 0 ||
msg[len(msg)-2] != '\r' ||
msg[len(msg)-1] != '\n' {
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
_, _ = client.conn.Write(errReply.ToBytes())
}
fixedLen = 0
}
if !client.uploading.Get() {
// new request
if msg[0] == '*' {
// bulk multi msg
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
if err != nil {
_, _ = client.conn.Write(UnknownErrReplyBytes)
continue
}
client.waitingReply.Add(1)
client.uploading.Set(true)
client.expectedArgsCount = uint32(expectedLine)
client.receivedCount = 0
client.args = make([][]byte, expectedLine)
logger.Info("connection closed: " + client.conn.RemoteAddr().String())
return
} else {
// text protocol
// remove \r or \n or \r\n in the end of line
str := strings.TrimSuffix(string(msg), "\n")
str = strings.TrimSuffix(str, "\r")
strs := strings.Split(str, " ")
args := make([][]byte, len(strs))
for i, s := range strs {
args[i] = []byte(s)
}
// send reply
result := h.db.Exec(client, args)
if result != nil {
_ = client.Write(result.ToBytes())
} else {
_ = client.Write(UnknownErrReplyBytes)
}
}
} else {
// receive following part of a request
line := msg[0 : len(msg)-2]
if line[0] == '$' {
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
// protocol err
errReply := reply.MakeErrReply(payload.Err.Error())
err := client.Write(errReply.ToBytes())
if err != nil {
errReply := &reply.ProtocolErrReply{Msg: err.Error()}
_, _ = client.conn.Write(errReply.ToBytes())
h.closeClient(client)
logger.Info("connection closed: " + client.conn.RemoteAddr().String())
return
}
if fixedLen <= 0 {
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
_, _ = client.conn.Write(errReply.ToBytes())
}
} else {
client.args[client.receivedCount] = line
client.receivedCount++
}
// if sending finished
if client.receivedCount == client.expectedArgsCount {
client.uploading.Set(false) // finish sending progress
// send reply
result := h.db.Exec(client, client.args)
if result != nil {
_ = client.Write(result.ToBytes())
} else {
_ = client.Write(UnknownErrReplyBytes)
}
// finish reply
client.expectedArgsCount = 0
client.receivedCount = 0
client.args = nil
client.waitingReply.Done()
continue
}
}
if payload.Data == nil {
logger.Error("empty payload")
continue
}
r, ok := payload.Data.(*reply.MultiBulkReply)
if !ok {
logger.Error("require multi bulk reply")
continue
}
result := h.db.Exec(client, r.Args)
if result != nil {
_ = client.Write(result.ToBytes())
} else {
_ = client.Write(UnknownErrReplyBytes)
}
}
}