diff --git a/database/single_db.go b/database/single_db.go index 5e402a5..509039c 100644 --- a/database/single_db.go +++ b/database/single_db.go @@ -101,8 +101,7 @@ func (db *DB) Exec(c redis.Connection, cmdLine [][]byte) redis.Reply { return Watch(db, c, cmdLine[1:]) } if c != nil && c.InMultiState() { - EnqueueCmd(c, cmdLine) - return protocol.MakeQueuedReply() + return EnqueueCmd(c, cmdLine) } return db.execNormalCommand(cmdLine) diff --git a/database/transaction.go b/database/transaction.go index ec82437..376de0c 100644 --- a/database/transaction.go +++ b/database/transaction.go @@ -51,14 +51,19 @@ func EnqueueCmd(conn redis.Connection, cmdLine [][]byte) redis.Reply { cmdName := strings.ToLower(string(cmdLine[0])) cmd, ok := cmdTable[cmdName] if !ok { - return protocol.MakeErrReply("ERR unknown command '" + cmdName + "'") + err := protocol.MakeErrReply("ERR unknown command '" + cmdName + "'") + conn.AddTxError(err) + return err } if cmd.prepare == nil { - return protocol.MakeErrReply("ERR command '" + cmdName + "' cannot be used in MULTI") + err := protocol.MakeErrReply("ERR command '" + cmdName + "' cannot be used in MULTI") + conn.AddTxError(err) + return err } if !validateArity(cmd.arity, cmdLine) { - // difference with redis: we won't enqueue command line with wrong arity - return protocol.MakeArgNumErrReply(cmdName) + err := protocol.MakeArgNumErrReply(cmdName) + conn.AddTxError(err) + return err } conn.EnqueueCmd(cmdLine) return protocol.MakeQueuedReply() @@ -69,6 +74,9 @@ func execMulti(db *DB, conn redis.Connection) redis.Reply { return protocol.MakeErrReply("ERR EXEC without MULTI") } defer conn.SetMultiState(false) + if len(conn.GetTxErrors()) > 0 { + return protocol.MakeErrReply("EXECABORT Transaction discarded because of previous errors.") + } cmdLines := conn.GetQueuedCmdLine() return db.ExecMulti(conn, conn.GetWatching(), cmdLines) } diff --git a/database/transaction_test.go b/database/transaction_test.go index e263d2c..bb65527 100644 --- a/database/transaction_test.go +++ b/database/transaction_test.go @@ -31,6 +31,20 @@ func TestMulti(t *testing.T) { } } +func TestSyntaxErr(t *testing.T) { + conn := new(connection.FakeConn) + testServer.Exec(conn, utils.ToCmdLine("FLUSHALL")) + result := testServer.Exec(conn, utils.ToCmdLine("multi")) + asserts.AssertNotError(t, result) + result = testServer.Exec(conn, utils.ToCmdLine("set")) + asserts.AssertErrReply(t, result, "ERR wrong number of arguments for 'set' command") + testServer.Exec(conn, utils.ToCmdLine("get", "a")) + result = testServer.Exec(conn, utils.ToCmdLine("exec")) + asserts.AssertErrReply(t, result, "EXECABORT Transaction discarded because of previous errors.") + result = testServer.Exec(conn, utils.ToCmdLine("get", "a")) + asserts.AssertNotError(t, result) +} + func TestRollback(t *testing.T) { conn := new(connection.FakeConn) testServer.Exec(conn, utils.ToCmdLine("FLUSHALL")) diff --git a/datastruct/dict/simple.go b/datastruct/dict/simple.go index 8f24a3c..1ad9402 100644 --- a/datastruct/dict/simple.go +++ b/datastruct/dict/simple.go @@ -72,6 +72,7 @@ func (dict *SimpleDict) Keys() []string { i := 0 for k := range dict.m { result[i] = k + i++ } return result } diff --git a/datastruct/dict/simple_test.go b/datastruct/dict/simple_test.go index 9a52bfd..74f4a4a 100644 --- a/datastruct/dict/simple_test.go +++ b/datastruct/dict/simple_test.go @@ -2,18 +2,34 @@ package dict import ( "github.com/hdt3213/godis/lib/utils" + "sort" "testing" ) func TestSimpleDict_Keys(t *testing.T) { d := MakeSimple() size := 10 + var expectKeys []string for i := 0; i < size; i++ { - d.Put(utils.RandString(5), utils.RandString(5)) + str := utils.RandString(5) + d.Put(str, str) + expectKeys = append(expectKeys, str) } - if len(d.Keys()) != size { + sort.Slice(expectKeys, func(i, j int) bool { + return expectKeys[i] > expectKeys[j] + }) + keys := d.Keys() + if len(keys) != size { t.Errorf("expect %d keys, actual: %d", size, len(d.Keys())) } + sort.Slice(keys, func(i, j int) bool { + return keys[i] > keys[j] + }) + for i, k := range keys { + if k != expectKeys[i] { + t.Errorf("expect %s actual %s", expectKeys[i], k) + } + } } func TestSimpleDict_PutIfExists(t *testing.T) { diff --git a/interface/redis/conn.go b/interface/redis/conn.go index a1b0fc5..05ca534 100644 --- a/interface/redis/conn.go +++ b/interface/redis/conn.go @@ -19,6 +19,8 @@ type Connection interface { EnqueueCmd([][]byte) ClearQueuedCmds() GetWatching() map[string]uint32 + AddTxError(err error) + GetTxErrors() []error // used for multi database GetDBIndex() int diff --git a/lib/sync/wait/wait.go b/lib/sync/wait/wait.go index 1fb99a7..331f5ed 100644 --- a/lib/sync/wait/wait.go +++ b/lib/sync/wait/wait.go @@ -31,7 +31,7 @@ func (w *Wait) WaitWithTimeout(timeout time.Duration) bool { c := make(chan bool, 1) go func() { defer close(c) - w.wg.Wait() + w.Wait() c <- true }() select { diff --git a/redis/connection/conn.go b/redis/connection/conn.go index 6aab461..67d5a0b 100644 --- a/redis/connection/conn.go +++ b/redis/connection/conn.go @@ -35,6 +35,7 @@ type Connection struct { multiState bool queue [][][]byte watching map[string]uint32 + txErrors []error // selected db selectedDB int @@ -65,11 +66,9 @@ func (c *Connection) Write(b []byte) error { if len(b) == 0 { return nil } - c.mu.Lock() c.waitingReply.Add(1) defer func() { c.waitingReply.Done() - c.mu.Unlock() }() _, err := c.conn.Write(b) @@ -151,6 +150,16 @@ func (c *Connection) EnqueueCmd(cmdLine [][]byte) { c.queue = append(c.queue, cmdLine) } +// AddTxError stores syntax error within transaction +func (c *Connection) AddTxError(err error) { + c.txErrors = append(c.txErrors, err) +} + +// GetTxErrors returns syntax error within transaction +func (c *Connection) GetTxErrors() []error { + return c.txErrors +} + // ClearQueuedCmds clears queued commands of current transaction func (c *Connection) ClearQueuedCmds() { c.queue = nil