diff --git a/database/string_test.go b/database/string_test.go index d1a379d..10bd114 100644 --- a/database/string_test.go +++ b/database/string_test.go @@ -25,6 +25,20 @@ func TestSet2(t *testing.T) { } } +func TestSetEmpty(t *testing.T) { + key := utils.RandString(10) + testDB.Exec(nil, utils.ToCmdLine("SET", key, "")) + actual := testDB.Exec(nil, utils.ToCmdLine("GET", key)) + bulkReply, ok := actual.(*protocol.BulkReply) + if !ok { + t.Errorf("expected bulk protocol, actually %s", actual.ToBytes()) + return + } + if !(bulkReply.Arg != nil && len(bulkReply.Arg) == 0) { + t.Error("illegal empty string") + } +} + func TestSet(t *testing.T) { testDB.Flush() key := utils.RandString(10) diff --git a/redis/parser/parser.go b/redis/parser/parser.go index 2ac6ed7..e22286c 100644 --- a/redis/parser/parser.go +++ b/redis/parser/parser.go @@ -240,7 +240,7 @@ func parseBulkHeader(msg []byte, state *readState) error { } if state.bulkLen == -1 { // null bulk return nil - } else if state.bulkLen > 0 { + } else if state.bulkLen >= 0 { state.msgType = msg[0] state.readingMultiLine = true state.expectedArgsCount = 1 @@ -281,13 +281,13 @@ func parseSingleLineReply(msg []byte) (redis.Reply, error) { func readBody(msg []byte, state *readState) error { line := msg[0 : len(msg)-2] var err error - if line[0] == '$' { + if len(line) > 0 && line[0] == '$' { // bulk protocol state.bulkLen, err = strconv.ParseInt(string(line[1:]), 10, 64) if err != nil { return errors.New("protocol error: " + string(msg)) } - if state.bulkLen <= 0 { // null bulk in multi bulks + if state.bulkLen < 0 { // null bulk in multi bulks state.args = append(state.args, []byte{}) state.bulkLen = 0 } diff --git a/redis/protocol/reply.go b/redis/protocol/reply.go index ed4bbbe..98cadb1 100644 --- a/redis/protocol/reply.go +++ b/redis/protocol/reply.go @@ -28,7 +28,7 @@ func MakeBulkReply(arg []byte) *BulkReply { // ToBytes marshal redis.Reply func (r *BulkReply) ToBytes() []byte { - if len(r.Arg) == 0 { + if r.Arg == nil { return nullBulkBytes } return []byte("$" + strconv.Itoa(len(r.Arg)) + CRLF + string(r.Arg) + CRLF)