diff --git a/redis/parser/parser.go b/redis/parser/parser.go index 1df0e6b..424edcb 100644 --- a/redis/parser/parser.go +++ b/redis/parser/parser.go @@ -4,13 +4,13 @@ import ( "bufio" "bytes" "errors" + "fmt" "github.com/hdt3213/godis/interface/redis" "github.com/hdt3213/godis/lib/logger" "github.com/hdt3213/godis/redis/protocol" "io" "runtime/debug" "strconv" - "strings" ) // Payload stores redis.Reply or error @@ -59,234 +59,134 @@ func ParseOne(data []byte) (redis.Reply, error) { return payload.Data, payload.Err } -type readState struct { - readingMultiLine bool - expectedArgsCount int - msgType byte - args [][]byte - bulkLen int64 - readingRepl bool -} - -func (s *readState) finished() bool { - return s.expectedArgsCount > 0 && len(s.args) == s.expectedArgsCount -} - -func parse0(reader io.Reader, ch chan<- *Payload) { +func parse0(rawReader io.Reader, ch chan<- *Payload) { defer func() { if err := recover(); err != nil { logger.Error(err, string(debug.Stack())) - close(ch) } }() - - bufReader := bufio.NewReader(reader) - var state readState - var err error - var msg []byte + reader := bufio.NewReader(rawReader) for { - // read line - var ioErr bool - msg, ioErr, err = readLine(bufReader, &state) + line, err := reader.ReadBytes('\n') if err != nil { ch <- &Payload{Err: err} - if ioErr { // encounter io err, stop read + close(ch) + return + } + length := len(line) + if length <= 2 || line[length-2] != '\r' { + protocolError(ch, line) + continue + } + line = bytes.TrimSuffix(line, []byte{'\r', '\n'}) + switch line[0] { + case '+': + ch <- &Payload{ + Data: protocol.MakeStatusReply(string(line[1:])), + } + case '-': + ch <- &Payload{ + Data: protocol.MakeErrReply(string(line[1:])), + } + case ':': + value, err := strconv.ParseInt(string(line[1:]), 10, 64) + if err != nil { + protocolError(ch, line) + continue + } + ch <- &Payload{ + Data: protocol.MakeIntReply(value), + } + case '$': + err = parseBulkString(line, reader, ch) + if err != nil { + ch <- &Payload{Err: err} close(ch) return } - // protocol err, reset read state - state = readState{} - continue - } - - // parse line - if !state.readingMultiLine { - // receive new response - if msg[0] == '*' { - // multi bulk protocol - err = parseMultiBulkHeader(msg, &state) - if err != nil { - ch <- &Payload{ - Err: errors.New("protocol error: " + string(msg)), - } - state = readState{} // reset state - continue - } - if state.expectedArgsCount == 0 { - ch <- &Payload{ - Data: &protocol.EmptyMultiBulkReply{}, - } - state = readState{} // reset state - continue - } - } else if msg[0] == '$' { // bulk protocol - err = parseBulkHeader(msg, &state) - if err != nil { - ch <- &Payload{ - Err: errors.New("protocol error: " + string(msg)), - } - state = readState{} // reset state - continue - } - if state.bulkLen == -1 { // null bulk protocol - ch <- &Payload{ - Data: &protocol.NullBulkReply{}, - } - state = readState{} // reset state - continue - } - } else { - // single line protocol - result, err := parseSingleLineReply(msg) - ch <- &Payload{ - Data: result, - Err: err, - } - state = readState{} // reset state - continue - } - } else { - // receive following bulk protocol - err = readBody(msg, &state) + case '*': + err = parseArray(line, reader, ch) if err != nil { - ch <- &Payload{ - Err: errors.New("protocol error: " + string(msg)), - } - state = readState{} // reset state - continue + ch <- &Payload{Err: err} + close(ch) + return } - // if sending finished - if state.finished() { - var result redis.Reply - if state.msgType == '*' { - result = protocol.MakeMultiBulkReply(state.args) - } else if state.msgType == '$' { - result = protocol.MakeBulkReply(state.args[0]) - } - ch <- &Payload{ - Data: result, - Err: err, - } - state = readState{} + default: + args := bytes.Split(line, []byte{' '}) + ch <- &Payload{ + Data: protocol.MakeMultiBulkReply(args), } } } } -func readLine(bufReader *bufio.Reader, state *readState) ([]byte, bool, error) { - var msg []byte - var err error - if state.bulkLen == 0 { // read normal line - msg, err = bufReader.ReadBytes('\n') - if err != nil { - return nil, true, err +func parseBulkString(header []byte, reader *bufio.Reader, ch chan<- *Payload) error { + strLen, err := strconv.ParseInt(string(header[1:]), 10, 64) + if err != nil || strLen < -1 { + protocolError(ch, header) + return nil + } else if strLen == -1 { + ch <- &Payload{ + Data: protocol.MakeNullBulkReply(), } - if len(msg) == 0 || msg[len(msg)-2] != '\r' { - return nil, false, errors.New("protocol error: " + string(msg)) - } - } else { // read bulk line (binary safe) - // there is CRLF between BulkReply in normal stream - // but there is no CRLF between RDB and following AOF - bulkLen := state.bulkLen + 2 - if state.readingRepl { - bulkLen -= 2 - } - msg = make([]byte, bulkLen) - _, err = io.ReadFull(bufReader, msg) - if err != nil { - return nil, true, err - } - //if len(msg) == 0 || - // msg[len(msg)-2] != '\r' || - // msg[len(msg)-1] != '\n' { - // return nil, false, errors.New("protocol error: " + string(msg)) - //} - state.bulkLen = 0 + return nil } - return msg, false, nil -} - -func parseMultiBulkHeader(msg []byte, state *readState) error { - var err error - var expectedLine uint64 - expectedLine, err = strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32) + body := make([]byte, strLen+2) + _, err = io.ReadFull(reader, body) if err != nil { - return errors.New("protocol error: " + string(msg)) + return err } - if expectedLine == 0 { - state.expectedArgsCount = 0 - return nil - } else if expectedLine > 0 { - // first line of multi bulk protocol - state.msgType = msg[0] - state.readingMultiLine = true - state.expectedArgsCount = int(expectedLine) - state.args = make([][]byte, 0, expectedLine) - return nil - } - return errors.New("protocol error: " + string(msg)) -} - -func parseBulkHeader(msg []byte, state *readState) error { - var err error - state.bulkLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64) - if err != nil { - return errors.New("protocol error: " + string(msg)) - } - if state.bulkLen == -1 { // null bulk - return nil - } else if state.bulkLen >= 0 { - state.msgType = msg[0] - state.readingMultiLine = true - state.expectedArgsCount = 1 - state.args = make([][]byte, 0, 1) - return nil - } - return errors.New("protocol error: " + string(msg)) -} - -func parseSingleLineReply(msg []byte) (redis.Reply, error) { - str := strings.TrimSuffix(string(msg), "\r\n") - var result redis.Reply - switch msg[0] { - case '+': // status protocol - result = protocol.MakeStatusReply(str[1:]) - case '-': // err protocol - result = protocol.MakeErrReply(str[1:]) - case ':': // int protocol - val, err := strconv.ParseInt(str[1:], 10, 64) - if err != nil { - return nil, errors.New("protocol error: " + string(msg)) - } - result = protocol.MakeIntReply(val) - default: - // parse as text protocol - strs := strings.Split(str, " ") - args := make([][]byte, len(strs)) - for i, s := range strs { - args[i] = []byte(s) - } - result = protocol.MakeMultiBulkReply(args) - } - return result, nil -} - -// read the non-first lines of multi bulk protocol or bulk protocol -func readBody(msg []byte, state *readState) error { - line := msg[0 : len(msg)-2] - var err error - if len(line) > 0 && line[0] == '$' { - // bulk protocol - state.bulkLen, err = strconv.ParseInt(string(line[1:]), 10, 64) - if err != nil { - return errors.New("protocol error: " + string(msg)) - } - if state.bulkLen < 0 { // null bulk in multi bulks - state.args = append(state.args, []byte{}) - state.bulkLen = 0 - } - } else { - state.args = append(state.args, line) + ch <- &Payload{ + Data: protocol.MakeBulkReply(body[:len(body)-2]), } return nil } + +func parseArray(header []byte, reader *bufio.Reader, ch chan<- *Payload) error { + nStrs, err := strconv.ParseInt(string(header[1:]), 10, 64) + if err != nil || nStrs < 0 { + protocolError(ch, header) + return nil + } else if nStrs == 0 { + ch <- &Payload{ + Data: protocol.MakeEmptyMultiBulkReply(), + } + return nil + } + lines := make([][]byte, 0, nStrs) + for i := int64(0); i < nStrs; i++ { + var line []byte + line, err = reader.ReadBytes('\n') + if err != nil { + return err + } + length := len(line) + if length < 4 || line[length-2] != '\r' || line[0] != '$' { + protocolError(ch, line) + break + } + strLen, err := strconv.ParseInt(string(line[1:length-2]), 10, 64) + if err != nil || strLen < -1 { + protocolError(ch, header) + break + } else if strLen == -1 { + lines = append(lines, []byte{}) + } else { + body := make([]byte, strLen+2) + _, err := io.ReadFull(reader, body) + if err != nil { + return err + } + lines = append(lines, body[:len(body)-2]) + } + } + ch <- &Payload{ + Data: protocol.MakeMultiBulkReply(lines), + } + return nil +} + +func protocolError(ch chan<- *Payload, msg []byte) { + err := errors.New(fmt.Sprintf("Protocol error: %s", string(msg))) + ch <- &Payload{Err: err} +}