tiny fixes

This commit is contained in:
finley
2022-08-20 22:28:06 +08:00
parent 23a6f95c18
commit 9da335811f
8 changed files with 60 additions and 11 deletions

View File

@@ -101,8 +101,7 @@ func (db *DB) Exec(c redis.Connection, cmdLine [][]byte) redis.Reply {
return Watch(db, c, cmdLine[1:]) return Watch(db, c, cmdLine[1:])
} }
if c != nil && c.InMultiState() { if c != nil && c.InMultiState() {
EnqueueCmd(c, cmdLine) return EnqueueCmd(c, cmdLine)
return protocol.MakeQueuedReply()
} }
return db.execNormalCommand(cmdLine) return db.execNormalCommand(cmdLine)

View File

@@ -51,14 +51,19 @@ func EnqueueCmd(conn redis.Connection, cmdLine [][]byte) redis.Reply {
cmdName := strings.ToLower(string(cmdLine[0])) cmdName := strings.ToLower(string(cmdLine[0]))
cmd, ok := cmdTable[cmdName] cmd, ok := cmdTable[cmdName]
if !ok { if !ok {
return protocol.MakeErrReply("ERR unknown command '" + cmdName + "'") err := protocol.MakeErrReply("ERR unknown command '" + cmdName + "'")
conn.AddTxError(err)
return err
} }
if cmd.prepare == nil { if cmd.prepare == nil {
return protocol.MakeErrReply("ERR command '" + cmdName + "' cannot be used in MULTI") err := protocol.MakeErrReply("ERR command '" + cmdName + "' cannot be used in MULTI")
conn.AddTxError(err)
return err
} }
if !validateArity(cmd.arity, cmdLine) { if !validateArity(cmd.arity, cmdLine) {
// difference with redis: we won't enqueue command line with wrong arity err := protocol.MakeArgNumErrReply(cmdName)
return protocol.MakeArgNumErrReply(cmdName) conn.AddTxError(err)
return err
} }
conn.EnqueueCmd(cmdLine) conn.EnqueueCmd(cmdLine)
return protocol.MakeQueuedReply() return protocol.MakeQueuedReply()
@@ -69,6 +74,9 @@ func execMulti(db *DB, conn redis.Connection) redis.Reply {
return protocol.MakeErrReply("ERR EXEC without MULTI") return protocol.MakeErrReply("ERR EXEC without MULTI")
} }
defer conn.SetMultiState(false) defer conn.SetMultiState(false)
if len(conn.GetTxErrors()) > 0 {
return protocol.MakeErrReply("EXECABORT Transaction discarded because of previous errors.")
}
cmdLines := conn.GetQueuedCmdLine() cmdLines := conn.GetQueuedCmdLine()
return db.ExecMulti(conn, conn.GetWatching(), cmdLines) return db.ExecMulti(conn, conn.GetWatching(), cmdLines)
} }

View File

@@ -31,6 +31,20 @@ func TestMulti(t *testing.T) {
} }
} }
func TestSyntaxErr(t *testing.T) {
conn := new(connection.FakeConn)
testServer.Exec(conn, utils.ToCmdLine("FLUSHALL"))
result := testServer.Exec(conn, utils.ToCmdLine("multi"))
asserts.AssertNotError(t, result)
result = testServer.Exec(conn, utils.ToCmdLine("set"))
asserts.AssertErrReply(t, result, "ERR wrong number of arguments for 'set' command")
testServer.Exec(conn, utils.ToCmdLine("get", "a"))
result = testServer.Exec(conn, utils.ToCmdLine("exec"))
asserts.AssertErrReply(t, result, "EXECABORT Transaction discarded because of previous errors.")
result = testServer.Exec(conn, utils.ToCmdLine("get", "a"))
asserts.AssertNotError(t, result)
}
func TestRollback(t *testing.T) { func TestRollback(t *testing.T) {
conn := new(connection.FakeConn) conn := new(connection.FakeConn)
testServer.Exec(conn, utils.ToCmdLine("FLUSHALL")) testServer.Exec(conn, utils.ToCmdLine("FLUSHALL"))

View File

@@ -72,6 +72,7 @@ func (dict *SimpleDict) Keys() []string {
i := 0 i := 0
for k := range dict.m { for k := range dict.m {
result[i] = k result[i] = k
i++
} }
return result return result
} }

View File

@@ -2,18 +2,34 @@ package dict
import ( import (
"github.com/hdt3213/godis/lib/utils" "github.com/hdt3213/godis/lib/utils"
"sort"
"testing" "testing"
) )
func TestSimpleDict_Keys(t *testing.T) { func TestSimpleDict_Keys(t *testing.T) {
d := MakeSimple() d := MakeSimple()
size := 10 size := 10
var expectKeys []string
for i := 0; i < size; i++ { for i := 0; i < size; i++ {
d.Put(utils.RandString(5), utils.RandString(5)) str := utils.RandString(5)
d.Put(str, str)
expectKeys = append(expectKeys, str)
} }
if len(d.Keys()) != size { sort.Slice(expectKeys, func(i, j int) bool {
return expectKeys[i] > expectKeys[j]
})
keys := d.Keys()
if len(keys) != size {
t.Errorf("expect %d keys, actual: %d", size, len(d.Keys())) t.Errorf("expect %d keys, actual: %d", size, len(d.Keys()))
} }
sort.Slice(keys, func(i, j int) bool {
return keys[i] > keys[j]
})
for i, k := range keys {
if k != expectKeys[i] {
t.Errorf("expect %s actual %s", expectKeys[i], k)
}
}
} }
func TestSimpleDict_PutIfExists(t *testing.T) { func TestSimpleDict_PutIfExists(t *testing.T) {

View File

@@ -19,6 +19,8 @@ type Connection interface {
EnqueueCmd([][]byte) EnqueueCmd([][]byte)
ClearQueuedCmds() ClearQueuedCmds()
GetWatching() map[string]uint32 GetWatching() map[string]uint32
AddTxError(err error)
GetTxErrors() []error
// used for multi database // used for multi database
GetDBIndex() int GetDBIndex() int

View File

@@ -31,7 +31,7 @@ func (w *Wait) WaitWithTimeout(timeout time.Duration) bool {
c := make(chan bool, 1) c := make(chan bool, 1)
go func() { go func() {
defer close(c) defer close(c)
w.wg.Wait() w.Wait()
c <- true c <- true
}() }()
select { select {

View File

@@ -35,6 +35,7 @@ type Connection struct {
multiState bool multiState bool
queue [][][]byte queue [][][]byte
watching map[string]uint32 watching map[string]uint32
txErrors []error
// selected db // selected db
selectedDB int selectedDB int
@@ -65,11 +66,9 @@ func (c *Connection) Write(b []byte) error {
if len(b) == 0 { if len(b) == 0 {
return nil return nil
} }
c.mu.Lock()
c.waitingReply.Add(1) c.waitingReply.Add(1)
defer func() { defer func() {
c.waitingReply.Done() c.waitingReply.Done()
c.mu.Unlock()
}() }()
_, err := c.conn.Write(b) _, err := c.conn.Write(b)
@@ -151,6 +150,16 @@ func (c *Connection) EnqueueCmd(cmdLine [][]byte) {
c.queue = append(c.queue, cmdLine) c.queue = append(c.queue, cmdLine)
} }
// AddTxError stores syntax error within transaction
func (c *Connection) AddTxError(err error) {
c.txErrors = append(c.txErrors, err)
}
// GetTxErrors returns syntax error within transaction
func (c *Connection) GetTxErrors() []error {
return c.txErrors
}
// ClearQueuedCmds clears queued commands of current transaction // ClearQueuedCmds clears queued commands of current transaction
func (c *Connection) ClearQueuedCmds() { func (c *Connection) ClearQueuedCmds() {
c.queue = nil c.queue = nil