mirror of
https://github.com/HDT3213/godis.git
synced 2025-11-01 20:42:43 +08:00
refactor RESP parser
This commit is contained in:
111
src/db/aof.go
111
src/db/aof.go
@@ -1,7 +1,6 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"github.com/hdt3213/godis/src/config"
|
||||
"github.com/hdt3213/godis/src/datastruct/dict"
|
||||
List "github.com/hdt3213/godis/src/datastruct/list"
|
||||
@@ -9,6 +8,8 @@ import (
|
||||
"github.com/hdt3213/godis/src/datastruct/set"
|
||||
SortedSet "github.com/hdt3213/godis/src/datastruct/sortedset"
|
||||
"github.com/hdt3213/godis/src/lib/logger"
|
||||
"github.com/hdt3213/godis/src/lib/utils"
|
||||
"github.com/hdt3213/godis/src/redis/parser"
|
||||
"github.com/hdt3213/godis/src/redis/reply"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@@ -72,6 +73,7 @@ func trim(msg []byte) string {
|
||||
return trimmed
|
||||
}
|
||||
|
||||
|
||||
// read aof file
|
||||
func (db *DB) loadAof(maxBytes int) {
|
||||
// delete aofChan to prevent write again
|
||||
@@ -91,96 +93,29 @@ func (db *DB) loadAof(maxBytes int) {
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
reader := bufio.NewReader(file)
|
||||
var fixedLen int64 = 0
|
||||
var expectedArgsCount uint32
|
||||
var receivedCount uint32
|
||||
var args [][]byte
|
||||
processing := false
|
||||
var msg []byte
|
||||
readBytes := 0
|
||||
for {
|
||||
if maxBytes != 0 && readBytes >= maxBytes {
|
||||
break
|
||||
reader := utils.NewLimitedReader(file, maxBytes)
|
||||
ch := parser.Parse(reader)
|
||||
for p := range ch {
|
||||
if p.Err != nil {
|
||||
if p.Err == io.EOF {
|
||||
break
|
||||
}
|
||||
logger.Error("parse error: " + p.Err.Error())
|
||||
continue
|
||||
}
|
||||
if fixedLen == 0 {
|
||||
msg, err = reader.ReadBytes('\n')
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
if len(msg) == 0 {
|
||||
logger.Warn("invalid format: line should end with \\r\\n")
|
||||
return
|
||||
}
|
||||
readBytes += len(msg)
|
||||
} else {
|
||||
msg = make([]byte, fixedLen+2)
|
||||
n, err := io.ReadFull(reader, msg)
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
if len(msg) == 0 {
|
||||
logger.Warn("invalid multibulk length")
|
||||
return
|
||||
}
|
||||
fixedLen = 0
|
||||
readBytes += n
|
||||
if p.Data == nil {
|
||||
logger.Error("empty payload")
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
return
|
||||
r, ok := p.Data.(*reply.MultiBulkReply)
|
||||
if !ok {
|
||||
logger.Error("require multi bulk reply")
|
||||
continue
|
||||
}
|
||||
|
||||
if !processing {
|
||||
// new request
|
||||
if msg[0] == '*' {
|
||||
// bulk multi msg
|
||||
expectedLine, err := strconv.ParseUint(trim(msg[1:]), 10, 32)
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
return
|
||||
}
|
||||
expectedArgsCount = uint32(expectedLine)
|
||||
receivedCount = 0
|
||||
processing = true
|
||||
args = make([][]byte, expectedLine)
|
||||
} else {
|
||||
logger.Warn("msg should start with '*'")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// receive following part of a request
|
||||
line := msg[0 : len(msg)-2]
|
||||
if line[0] == '$' {
|
||||
fixedLen, err = strconv.ParseInt(trim(line[1:]), 10, 64)
|
||||
if err != nil {
|
||||
logger.Warn(err)
|
||||
return
|
||||
}
|
||||
if fixedLen <= 0 {
|
||||
logger.Warn("invalid multibulk length")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
args[receivedCount] = line
|
||||
receivedCount++
|
||||
}
|
||||
|
||||
// if sending finished
|
||||
if receivedCount == expectedArgsCount {
|
||||
processing = false
|
||||
|
||||
cmd := strings.ToLower(string(args[0]))
|
||||
cmdFunc, ok := router[cmd]
|
||||
if ok {
|
||||
cmdFunc(db, args[1:])
|
||||
}
|
||||
|
||||
// finish
|
||||
expectedArgsCount = 0
|
||||
receivedCount = 0
|
||||
args = nil
|
||||
}
|
||||
cmd := strings.ToLower(string(r.Args[0]))
|
||||
cmdFunc, ok := router[cmd]
|
||||
if ok {
|
||||
cmdFunc(db, r.Args[1:])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,6 +37,12 @@ const (
|
||||
FATAL
|
||||
)
|
||||
|
||||
const flags = log.LstdFlags
|
||||
|
||||
func init() {
|
||||
logger = log.New(os.Stdout, DefaultPrefix, flags)
|
||||
}
|
||||
|
||||
func Setup(settings *Settings) {
|
||||
var err error
|
||||
dir := settings.Path
|
||||
@@ -51,7 +57,7 @@ func Setup(settings *Settings) {
|
||||
}
|
||||
|
||||
mw := io.MultiWriter(os.Stdout, logFile)
|
||||
logger = log.New(mw, DefaultPrefix, log.LstdFlags)
|
||||
logger = log.New(mw, DefaultPrefix, flags)
|
||||
}
|
||||
|
||||
func setPrefix(level Level) {
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
package gob
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
)
|
||||
|
||||
func Marshal(obj interface{}) ([]byte, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
enc := gob.NewEncoder(buf)
|
||||
err := enc.Encode(obj)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func UnMarshal(data []byte, obj interface{}) error {
|
||||
dec := gob.NewDecoder(bytes.NewBuffer(data))
|
||||
return dec.Decode(obj)
|
||||
}
|
||||
34
src/lib/utils/limited_reader.go
Normal file
34
src/lib/utils/limited_reader.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
type LimitedReader struct {
|
||||
src io.Reader
|
||||
n int
|
||||
limit int
|
||||
}
|
||||
|
||||
func NewLimitedReader(src io.Reader, limit int) *LimitedReader {
|
||||
return &LimitedReader{
|
||||
src: src,
|
||||
limit: limit,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *LimitedReader) Read(p []byte) (n int, err error) {
|
||||
if r.src == nil {
|
||||
return 0, errors.New("no data source")
|
||||
}
|
||||
if r.limit > 0 && r.n >= r.limit {
|
||||
return 0, io.EOF
|
||||
}
|
||||
n, err = r.src.Read(p)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
r.n += n
|
||||
return
|
||||
}
|
||||
@@ -1,31 +1,25 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/hdt3213/godis/src/interface/redis"
|
||||
"github.com/hdt3213/godis/src/lib/logger"
|
||||
"github.com/hdt3213/godis/src/lib/sync/wait"
|
||||
"github.com/hdt3213/godis/src/redis/parser"
|
||||
"github.com/hdt3213/godis/src/redis/reply"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
conn net.Conn
|
||||
sendingReqs chan *Request // waiting sending
|
||||
pendingReqs chan *Request // wait to send
|
||||
waitingReqs chan *Request // waiting response
|
||||
ticker *time.Ticker
|
||||
addr string
|
||||
|
||||
ctx context.Context
|
||||
cancelFunc context.CancelFunc
|
||||
writing *sync.WaitGroup
|
||||
working *sync.WaitGroup // its counter presents unfinished requests(pending and waiting)
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
@@ -47,15 +41,12 @@ func MakeClient(addr string) (*Client, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Client{
|
||||
addr: addr,
|
||||
conn: conn,
|
||||
sendingReqs: make(chan *Request, chanSize),
|
||||
pendingReqs: make(chan *Request, chanSize),
|
||||
waitingReqs: make(chan *Request, chanSize),
|
||||
ctx: ctx,
|
||||
cancelFunc: cancel,
|
||||
writing: &sync.WaitGroup{},
|
||||
working: &sync.WaitGroup{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -64,20 +55,22 @@ func (client *Client) Start() {
|
||||
go client.handleWrite()
|
||||
go func() {
|
||||
err := client.handleRead()
|
||||
logger.Warn(err)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
}()
|
||||
go client.heartbeat()
|
||||
}
|
||||
|
||||
func (client *Client) Close() {
|
||||
client.ticker.Stop()
|
||||
// stop new request
|
||||
close(client.sendingReqs)
|
||||
close(client.pendingReqs)
|
||||
|
||||
// wait stop process
|
||||
client.writing.Wait()
|
||||
client.working.Wait()
|
||||
|
||||
// clean
|
||||
client.cancelFunc()
|
||||
_ = client.conn.Close()
|
||||
close(client.waitingReqs)
|
||||
}
|
||||
@@ -106,34 +99,17 @@ func (client *Client) handleConnectionError(err error) error {
|
||||
}
|
||||
|
||||
func (client *Client) heartbeat() {
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case <-client.ticker.C:
|
||||
client.sendingReqs <- &Request{
|
||||
args: [][]byte{[]byte("PING")},
|
||||
heartbeat: true,
|
||||
}
|
||||
case <-client.ctx.Done():
|
||||
break loop
|
||||
}
|
||||
for range client.ticker.C {
|
||||
client.doHeartbeat()
|
||||
}
|
||||
}
|
||||
|
||||
func (client *Client) handleWrite() {
|
||||
loop:
|
||||
for {
|
||||
select {
|
||||
case req := <-client.sendingReqs:
|
||||
client.writing.Add(1)
|
||||
client.doRequest(req)
|
||||
case <-client.ctx.Done():
|
||||
break loop
|
||||
}
|
||||
for req := range client.pendingReqs {
|
||||
client.doRequest(req)
|
||||
}
|
||||
}
|
||||
|
||||
// todo: wait with timeout
|
||||
func (client *Client) Send(args [][]byte) redis.Reply {
|
||||
request := &Request{
|
||||
args: args,
|
||||
@@ -141,7 +117,9 @@ func (client *Client) Send(args [][]byte) redis.Reply {
|
||||
waiting: &wait.Wait{},
|
||||
}
|
||||
request.waiting.Add(1)
|
||||
client.sendingReqs <- request
|
||||
client.working.Add(1)
|
||||
defer client.working.Done()
|
||||
client.pendingReqs <- request
|
||||
timeout := request.waiting.WaitWithTimeout(maxWait)
|
||||
if timeout {
|
||||
return reply.MakeErrReply("server time out")
|
||||
@@ -152,8 +130,24 @@ func (client *Client) Send(args [][]byte) redis.Reply {
|
||||
return request.reply
|
||||
}
|
||||
|
||||
func (client *Client) doHeartbeat() {
|
||||
request := &Request{
|
||||
args: [][]byte{[]byte("PING")},
|
||||
heartbeat: true,
|
||||
}
|
||||
request.waiting.Add(1)
|
||||
client.working.Add(1)
|
||||
defer client.working.Done()
|
||||
client.pendingReqs <- request
|
||||
request.waiting.WaitWithTimeout(maxWait)
|
||||
}
|
||||
|
||||
func (client *Client) doRequest(req *Request) {
|
||||
bytes := reply.MakeMultiBulkReply(req.args).ToBytes()
|
||||
if req == nil || len(req.args) == 0 {
|
||||
return
|
||||
}
|
||||
re := reply.MakeMultiBulkReply(req.args)
|
||||
bytes := re.ToBytes()
|
||||
_, err := client.conn.Write(bytes)
|
||||
i := 0
|
||||
for err != nil && i < 3 {
|
||||
@@ -168,154 +162,34 @@ func (client *Client) doRequest(req *Request) {
|
||||
} else {
|
||||
req.err = err
|
||||
req.waiting.Done()
|
||||
client.writing.Done()
|
||||
}
|
||||
}
|
||||
|
||||
func (client *Client) finishRequest(reply redis.Reply) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
debug.PrintStack()
|
||||
logger.Error(err)
|
||||
}
|
||||
}()
|
||||
request := <-client.waitingReqs
|
||||
if request == nil {
|
||||
return
|
||||
}
|
||||
request.reply = reply
|
||||
if request.waiting != nil {
|
||||
request.waiting.Done()
|
||||
}
|
||||
client.writing.Done()
|
||||
}
|
||||
|
||||
func (client *Client) handleRead() error {
|
||||
reader := bufio.NewReader(client.conn)
|
||||
downloading := false
|
||||
expectedArgsCount := 0
|
||||
receivedCount := 0
|
||||
msgType := byte(0) // first char of msg
|
||||
var args [][]byte
|
||||
var fixedLen int64 = 0
|
||||
var err error
|
||||
var msg []byte
|
||||
for {
|
||||
// read line
|
||||
if fixedLen == 0 { // read normal line
|
||||
msg, err = reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
logger.Info("connection close")
|
||||
} else {
|
||||
logger.Warn(err)
|
||||
}
|
||||
|
||||
return errors.New("connection closed")
|
||||
}
|
||||
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
|
||||
return errors.New("protocol error")
|
||||
}
|
||||
} else { // read bulk line (binary safe)
|
||||
msg = make([]byte, fixedLen+2)
|
||||
_, err = io.ReadFull(reader, msg)
|
||||
if err != nil {
|
||||
if err == io.EOF || err == io.ErrUnexpectedEOF {
|
||||
return errors.New("connection closed")
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(msg) == 0 ||
|
||||
msg[len(msg)-2] != '\r' ||
|
||||
msg[len(msg)-1] != '\n' {
|
||||
return errors.New("protocol error")
|
||||
}
|
||||
fixedLen = 0
|
||||
}
|
||||
|
||||
// parse line
|
||||
if !downloading {
|
||||
// receive new response
|
||||
if msg[0] == '*' { // multi bulk response
|
||||
// bulk multi msg
|
||||
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
|
||||
if err != nil {
|
||||
return errors.New("protocol error: " + err.Error())
|
||||
}
|
||||
if expectedLine == 0 {
|
||||
client.finishRequest(&reply.EmptyMultiBulkReply{})
|
||||
} else if expectedLine > 0 {
|
||||
msgType = msg[0]
|
||||
downloading = true
|
||||
expectedArgsCount = int(expectedLine)
|
||||
receivedCount = 0
|
||||
args = make([][]byte, expectedLine)
|
||||
} else {
|
||||
return errors.New("protocol error")
|
||||
}
|
||||
} else if msg[0] == '$' { // bulk response
|
||||
fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fixedLen == -1 { // null bulk
|
||||
client.finishRequest(&reply.NullBulkReply{})
|
||||
fixedLen = 0
|
||||
} else if fixedLen > 0 {
|
||||
msgType = msg[0]
|
||||
downloading = true
|
||||
expectedArgsCount = 1
|
||||
receivedCount = 0
|
||||
args = make([][]byte, 1)
|
||||
} else {
|
||||
return errors.New("protocol error")
|
||||
}
|
||||
} else { // single line response
|
||||
str := strings.TrimSuffix(string(msg), "\n")
|
||||
str = strings.TrimSuffix(str, "\r")
|
||||
var result redis.Reply
|
||||
switch msg[0] {
|
||||
case '+':
|
||||
result = reply.MakeStatusReply(str[1:])
|
||||
case '-':
|
||||
result = reply.MakeErrReply(str[1:])
|
||||
case ':':
|
||||
val, err := strconv.ParseInt(str[1:], 10, 64)
|
||||
if err != nil {
|
||||
return errors.New("protocol error")
|
||||
}
|
||||
result = reply.MakeIntReply(val)
|
||||
}
|
||||
client.finishRequest(result)
|
||||
}
|
||||
} else {
|
||||
// receive following part of a request
|
||||
line := msg[0 : len(msg)-2]
|
||||
if line[0] == '$' {
|
||||
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fixedLen <= 0 { // null bulk in multi bulks
|
||||
args[receivedCount] = []byte{}
|
||||
receivedCount++
|
||||
fixedLen = 0
|
||||
}
|
||||
} else {
|
||||
args[receivedCount] = line
|
||||
receivedCount++
|
||||
}
|
||||
|
||||
// if sending finished
|
||||
if receivedCount == expectedArgsCount {
|
||||
downloading = false // finish downloading progress
|
||||
|
||||
if msgType == '*' {
|
||||
reply := reply.MakeMultiBulkReply(args)
|
||||
client.finishRequest(reply)
|
||||
} else if msgType == '$' {
|
||||
reply := reply.MakeBulkReply(args[0])
|
||||
client.finishRequest(reply)
|
||||
}
|
||||
|
||||
// finish reply
|
||||
expectedArgsCount = 0
|
||||
receivedCount = 0
|
||||
args = nil
|
||||
msgType = byte(0)
|
||||
}
|
||||
ch := parser.Parse(client.conn)
|
||||
for payload := range ch {
|
||||
if payload.Err != nil {
|
||||
client.finishRequest(reply.MakeErrReply(payload.Err.Error()))
|
||||
continue
|
||||
}
|
||||
client.finishRequest(payload.Data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
262
src/redis/parser/parser.go
Normal file
262
src/redis/parser/parser.go
Normal file
@@ -0,0 +1,262 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"github.com/hdt3213/godis/src/interface/redis"
|
||||
"github.com/hdt3213/godis/src/lib/logger"
|
||||
"github.com/hdt3213/godis/src/redis/reply"
|
||||
"io"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Payload struct {
|
||||
Data redis.Reply
|
||||
Err error
|
||||
}
|
||||
|
||||
func Parse(reader io.Reader) <-chan *Payload {
|
||||
ch := make(chan *Payload)
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Error(debug.Stack())
|
||||
}
|
||||
}()
|
||||
parse0(reader, ch)
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
type readState struct {
|
||||
downloading bool
|
||||
expectedArgsCount int
|
||||
receivedCount int
|
||||
msgType byte
|
||||
args [][]byte
|
||||
fixedLen int64
|
||||
}
|
||||
|
||||
func (s *readState) finished() bool {
|
||||
return s.expectedArgsCount > 0 && s.receivedCount == s.expectedArgsCount
|
||||
}
|
||||
|
||||
func parse0(reader io.Reader, ch chan<- *Payload) {
|
||||
bufReader := bufio.NewReader(reader)
|
||||
var state readState
|
||||
var err error
|
||||
var msg []byte
|
||||
for {
|
||||
// read line
|
||||
var ioErr bool
|
||||
msg, ioErr, err = readLine(bufReader, &state)
|
||||
if err != nil {
|
||||
if ioErr { // encounter io err, stop read
|
||||
ch <- &Payload{
|
||||
Err: err,
|
||||
}
|
||||
close(ch)
|
||||
return
|
||||
} else { // protocol err, reset read state
|
||||
ch <- &Payload{
|
||||
Err: err,
|
||||
}
|
||||
state = readState{}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// parse line
|
||||
if !state.downloading {
|
||||
// receive new response
|
||||
if msg[0] == '*' {
|
||||
// multi bulk reply
|
||||
err = parseMultiBulkHeader(msg, &state)
|
||||
if err != nil {
|
||||
ch <- &Payload{
|
||||
Err: errors.New("protocol error: " + string(msg)),
|
||||
}
|
||||
state = readState{} // reset state
|
||||
continue
|
||||
}
|
||||
if state.expectedArgsCount == 0 {
|
||||
ch <- &Payload{
|
||||
Data: &reply.EmptyMultiBulkReply{},
|
||||
}
|
||||
state = readState{} // reset state
|
||||
continue
|
||||
}
|
||||
} else if msg[0] == '$' { // bulk reply
|
||||
err = parseBulkHeader(msg, &state)
|
||||
if err != nil {
|
||||
ch <- &Payload{
|
||||
Err: errors.New("protocol error: " + string(msg)),
|
||||
}
|
||||
state = readState{} // reset state
|
||||
continue
|
||||
}
|
||||
if state.fixedLen == -1 { // null bulk reply
|
||||
ch <- &Payload{
|
||||
Data: &reply.NullBulkReply{},
|
||||
}
|
||||
state = readState{} // reset state
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// single line reply
|
||||
result, err := parseSingleLineReply(msg)
|
||||
ch <- &Payload{
|
||||
Data: result,
|
||||
Err: err,
|
||||
}
|
||||
state = readState{} // reset state
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
// receive following bulk reply
|
||||
err = readBulkBody(msg, &state)
|
||||
if err != nil {
|
||||
ch <- &Payload{
|
||||
Err: errors.New("protocol error: " + string(msg)),
|
||||
}
|
||||
state = readState{} // reset state
|
||||
continue
|
||||
}
|
||||
// if sending finished
|
||||
if state.finished() {
|
||||
var result redis.Reply
|
||||
if state.msgType == '*' {
|
||||
result = reply.MakeMultiBulkReply(state.args)
|
||||
} else if state.msgType == '$' {
|
||||
result = reply.MakeBulkReply(state.args[0])
|
||||
}
|
||||
ch <- &Payload{
|
||||
Data: result,
|
||||
Err: err,
|
||||
}
|
||||
state = readState{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) {
|
||||
var msg []byte
|
||||
var err error
|
||||
if state.fixedLen == 0 { // read normal line
|
||||
msg, err = bufReader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
|
||||
return nil, false, errors.New("protocol error: " + string(msg))
|
||||
}
|
||||
} else { // read bulk line (binary safe)
|
||||
msg = make([]byte, state.fixedLen+2)
|
||||
_, err = io.ReadFull(bufReader, msg)
|
||||
if err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
if len(msg) == 0 ||
|
||||
msg[len(msg)-2] != '\r' ||
|
||||
msg[len(msg)-1] != '\n' {
|
||||
return nil, false, errors.New("protocol error: " + string(msg))
|
||||
}
|
||||
state.fixedLen = 0
|
||||
}
|
||||
return msg, false, nil
|
||||
}
|
||||
|
||||
func parseMultiBulkHeader(msg []byte, state *readState) error {
|
||||
var err error
|
||||
var expectedLine uint64
|
||||
expectedLine, err = strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
|
||||
if err != nil {
|
||||
return errors.New("protocol error: " + string(msg))
|
||||
}
|
||||
if expectedLine == 0 {
|
||||
state.expectedArgsCount = 0
|
||||
return nil
|
||||
} else if expectedLine > 0 {
|
||||
// first line of multi bulk reply
|
||||
state.msgType = msg[0]
|
||||
state.downloading = true
|
||||
state.expectedArgsCount = int(expectedLine)
|
||||
state.receivedCount = 0
|
||||
state.args = make([][]byte, expectedLine)
|
||||
return nil
|
||||
} else {
|
||||
return errors.New("protocol error: " + string(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func parseBulkHeader(msg []byte, state *readState) error {
|
||||
var err error
|
||||
state.fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
|
||||
if err != nil {
|
||||
return errors.New("protocol error: " + string(msg))
|
||||
}
|
||||
if state.fixedLen == -1 { // null bulk
|
||||
return nil
|
||||
} else if state.fixedLen > 0 {
|
||||
state.msgType = msg[0]
|
||||
state.downloading = true
|
||||
state.expectedArgsCount = 1
|
||||
state.receivedCount = 0
|
||||
state.args = make([][]byte, 1)
|
||||
return nil
|
||||
} else {
|
||||
return errors.New("protocol error: " + string(msg))
|
||||
}
|
||||
}
|
||||
|
||||
func parseSingleLineReply(msg []byte) (redis.Reply, error) {
|
||||
str := strings.TrimSuffix(string(msg), "\n")
|
||||
str = strings.TrimSuffix(str, "\r")
|
||||
var result redis.Reply
|
||||
switch msg[0] {
|
||||
case '+': // status reply
|
||||
result = reply.MakeStatusReply(str[1:])
|
||||
case '-': // err reply
|
||||
result = reply.MakeErrReply(str[1:])
|
||||
case ':': // int reply
|
||||
val, err := strconv.ParseInt(str[1:], 10, 64)
|
||||
if err != nil {
|
||||
return nil, errors.New("protocol error: " + string(msg))
|
||||
}
|
||||
result = reply.MakeIntReply(val)
|
||||
default:
|
||||
// parse as text protocol
|
||||
strs := strings.Split(str, " ")
|
||||
args := make([][]byte, len(strs))
|
||||
for i, s := range strs {
|
||||
args[i] = []byte(s)
|
||||
}
|
||||
result = reply.MakeMultiBulkReply(args)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// read the non-first lines of multi bulk reply or bulk reply
|
||||
func readBulkBody(msg []byte, state *readState) error {
|
||||
line := msg[0 : len(msg)-2]
|
||||
var err error
|
||||
if line[0] == '$' {
|
||||
// bulk reply
|
||||
state.fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
|
||||
if err != nil {
|
||||
return errors.New("protocol error: " + string(msg))
|
||||
}
|
||||
if state.fixedLen <= 0 { // null bulk in multi bulks
|
||||
state.args[state.receivedCount] = []byte{}
|
||||
state.receivedCount++
|
||||
state.fixedLen = 0
|
||||
}
|
||||
} else {
|
||||
state.args[state.receivedCount] = line
|
||||
state.receivedCount++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
56
src/redis/parser/parser_test.go
Normal file
56
src/redis/parser/parser_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/hdt3213/godis/src/datastruct/utils"
|
||||
"github.com/hdt3213/godis/src/interface/redis"
|
||||
"github.com/hdt3213/godis/src/redis/reply"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
replies := []redis.Reply{
|
||||
reply.MakeIntReply(1),
|
||||
reply.MakeStatusReply("OK"),
|
||||
reply.MakeErrReply("ERR unknown"),
|
||||
reply.MakeBulkReply([]byte("a\r\nb")), // test binary safe
|
||||
reply.MakeNullBulkReply(),
|
||||
reply.MakeMultiBulkReply([][]byte{
|
||||
[]byte("a"),
|
||||
[]byte("\r\n"),
|
||||
}),
|
||||
reply.MakeEmptyMultiBulkReply(),
|
||||
}
|
||||
reqs := bytes.Buffer{}
|
||||
for _, re := range replies {
|
||||
reqs.Write(re.ToBytes())
|
||||
}
|
||||
reqs.Write([]byte("set a a" + reply.CRLF)) // test text protocol
|
||||
expected := make([]redis.Reply, len(replies))
|
||||
copy(expected, replies)
|
||||
expected = append(expected, reply.MakeMultiBulkReply([][]byte{
|
||||
[]byte("set"), []byte("a"), []byte("a"),
|
||||
}))
|
||||
|
||||
ch := Parse(bytes.NewReader(reqs.Bytes()))
|
||||
i := 0
|
||||
for payload := range ch {
|
||||
if payload.Err != nil {
|
||||
if payload.Err == io.EOF {
|
||||
return
|
||||
}
|
||||
t.Error(payload.Err)
|
||||
return
|
||||
}
|
||||
if payload.Data == nil {
|
||||
t.Error("empty data")
|
||||
return
|
||||
}
|
||||
exp := expected[i]
|
||||
i++
|
||||
if !utils.BytesEquals(exp.ToBytes(), payload.Data.ToBytes()) {
|
||||
t.Error("parse failed: " + string(exp.ToBytes()))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -24,6 +24,10 @@ func (r *NullBulkReply) ToBytes() []byte {
|
||||
return nullBulkBytes
|
||||
}
|
||||
|
||||
func MakeNullBulkReply() *NullBulkReply {
|
||||
return &NullBulkReply{}
|
||||
}
|
||||
|
||||
var emptyMultiBulkBytes = []byte("*0\r\n")
|
||||
|
||||
type EmptyMultiBulkReply struct{}
|
||||
@@ -32,6 +36,10 @@ func (r *EmptyMultiBulkReply) ToBytes() []byte {
|
||||
return emptyMultiBulkBytes
|
||||
}
|
||||
|
||||
func MakeEmptyMultiBulkReply() *EmptyMultiBulkReply {
|
||||
return &EmptyMultiBulkReply{}
|
||||
}
|
||||
|
||||
// reply nothing, for commands like subscribe
|
||||
type NoReply struct{}
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/hdt3213/godis/src/lib/sync/atomic"
|
||||
"github.com/hdt3213/godis/src/lib/sync/wait"
|
||||
"net"
|
||||
"sync"
|
||||
@@ -15,15 +14,6 @@ type Client struct {
|
||||
// waiting util reply finished
|
||||
waitingReply wait.Wait
|
||||
|
||||
// is sending request in progress
|
||||
uploading atomic.AtomicBool
|
||||
// multi bulk msg lineCount - 1(first line)
|
||||
expectedArgsCount uint32
|
||||
// sent line count, exclude first line
|
||||
receivedCount uint32
|
||||
// sent lines, exclude first line
|
||||
args [][]byte
|
||||
|
||||
// lock while server sending response
|
||||
mu sync.Mutex
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ package server
|
||||
*/
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"github.com/hdt3213/godis/src/cluster"
|
||||
"github.com/hdt3213/godis/src/config"
|
||||
@@ -13,10 +12,10 @@ import (
|
||||
"github.com/hdt3213/godis/src/interface/db"
|
||||
"github.com/hdt3213/godis/src/lib/logger"
|
||||
"github.com/hdt3213/godis/src/lib/sync/atomic"
|
||||
"github.com/hdt3213/godis/src/redis/parser"
|
||||
"github.com/hdt3213/godis/src/redis/reply"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
@@ -59,126 +58,43 @@ func (h *Handler) Handle(ctx context.Context, conn net.Conn) {
|
||||
client := MakeClient(conn)
|
||||
h.activeConn.Store(client, 1)
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
var fixedLen int64 = 0
|
||||
var err error
|
||||
var msg []byte
|
||||
for {
|
||||
if fixedLen == 0 {
|
||||
msg, err = reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF ||
|
||||
err == io.ErrUnexpectedEOF ||
|
||||
strings.Contains(err.Error(), "use of closed network connection") {
|
||||
logger.Info("connection close")
|
||||
} else {
|
||||
logger.Warn(err)
|
||||
}
|
||||
|
||||
// after client close
|
||||
ch := parser.Parse(conn)
|
||||
for payload := range ch {
|
||||
if payload.Err != nil {
|
||||
if payload.Err == io.EOF ||
|
||||
payload.Err == io.ErrUnexpectedEOF ||
|
||||
strings.Contains(payload.Err.Error(), "use of closed network connection") {
|
||||
// connection closed
|
||||
h.closeClient(client)
|
||||
return // io error, disconnect with client
|
||||
}
|
||||
if len(msg) == 0 || msg[len(msg)-2] != '\r' {
|
||||
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
|
||||
_, _ = client.conn.Write(errReply.ToBytes())
|
||||
}
|
||||
} else {
|
||||
msg = make([]byte, fixedLen+2)
|
||||
_, err = io.ReadFull(reader, msg)
|
||||
if err != nil {
|
||||
if err == io.EOF ||
|
||||
err == io.ErrUnexpectedEOF ||
|
||||
strings.Contains(err.Error(), "use of closed network connection") {
|
||||
logger.Info("connection close")
|
||||
} else {
|
||||
logger.Warn(err)
|
||||
}
|
||||
|
||||
// after client close
|
||||
h.closeClient(client)
|
||||
return // io error, disconnect with client
|
||||
}
|
||||
if len(msg) == 0 ||
|
||||
msg[len(msg)-2] != '\r' ||
|
||||
msg[len(msg)-1] != '\n' {
|
||||
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
|
||||
_, _ = client.conn.Write(errReply.ToBytes())
|
||||
}
|
||||
fixedLen = 0
|
||||
}
|
||||
|
||||
if !client.uploading.Get() {
|
||||
// new request
|
||||
if msg[0] == '*' {
|
||||
// bulk multi msg
|
||||
expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
|
||||
if err != nil {
|
||||
_, _ = client.conn.Write(UnknownErrReplyBytes)
|
||||
continue
|
||||
}
|
||||
client.waitingReply.Add(1)
|
||||
client.uploading.Set(true)
|
||||
client.expectedArgsCount = uint32(expectedLine)
|
||||
client.receivedCount = 0
|
||||
client.args = make([][]byte, expectedLine)
|
||||
logger.Info("connection closed: " + client.conn.RemoteAddr().String())
|
||||
return
|
||||
} else {
|
||||
// text protocol
|
||||
// remove \r or \n or \r\n in the end of line
|
||||
str := strings.TrimSuffix(string(msg), "\n")
|
||||
str = strings.TrimSuffix(str, "\r")
|
||||
strs := strings.Split(str, " ")
|
||||
args := make([][]byte, len(strs))
|
||||
for i, s := range strs {
|
||||
args[i] = []byte(s)
|
||||
}
|
||||
|
||||
// send reply
|
||||
result := h.db.Exec(client, args)
|
||||
if result != nil {
|
||||
_ = client.Write(result.ToBytes())
|
||||
} else {
|
||||
_ = client.Write(UnknownErrReplyBytes)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// receive following part of a request
|
||||
line := msg[0 : len(msg)-2]
|
||||
if line[0] == '$' {
|
||||
fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
|
||||
// protocol err
|
||||
errReply := reply.MakeErrReply(payload.Err.Error())
|
||||
err := client.Write(errReply.ToBytes())
|
||||
if err != nil {
|
||||
errReply := &reply.ProtocolErrReply{Msg: err.Error()}
|
||||
_, _ = client.conn.Write(errReply.ToBytes())
|
||||
h.closeClient(client)
|
||||
logger.Info("connection closed: " + client.conn.RemoteAddr().String())
|
||||
return
|
||||
}
|
||||
if fixedLen <= 0 {
|
||||
errReply := &reply.ProtocolErrReply{Msg: "invalid multibulk length"}
|
||||
_, _ = client.conn.Write(errReply.ToBytes())
|
||||
}
|
||||
} else {
|
||||
client.args[client.receivedCount] = line
|
||||
client.receivedCount++
|
||||
}
|
||||
|
||||
// if sending finished
|
||||
if client.receivedCount == client.expectedArgsCount {
|
||||
client.uploading.Set(false) // finish sending progress
|
||||
|
||||
// send reply
|
||||
result := h.db.Exec(client, client.args)
|
||||
if result != nil {
|
||||
_ = client.Write(result.ToBytes())
|
||||
} else {
|
||||
_ = client.Write(UnknownErrReplyBytes)
|
||||
}
|
||||
|
||||
// finish reply
|
||||
client.expectedArgsCount = 0
|
||||
client.receivedCount = 0
|
||||
client.args = nil
|
||||
client.waitingReply.Done()
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if payload.Data == nil {
|
||||
logger.Error("empty payload")
|
||||
continue
|
||||
}
|
||||
r, ok := payload.Data.(*reply.MultiBulkReply)
|
||||
if !ok {
|
||||
logger.Error("require multi bulk reply")
|
||||
continue
|
||||
}
|
||||
result := h.db.Exec(client, r.Args)
|
||||
if result != nil {
|
||||
_ = client.Write(result.ToBytes())
|
||||
} else {
|
||||
_ = client.Write(UnknownErrReplyBytes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user