mirror of
https://github.com/HDT3213/godis.git
synced 2025-10-05 16:57:06 +08:00
tiny fixes
This commit is contained in:
@@ -101,8 +101,7 @@ func (db *DB) Exec(c redis.Connection, cmdLine [][]byte) redis.Reply {
|
|||||||
return Watch(db, c, cmdLine[1:])
|
return Watch(db, c, cmdLine[1:])
|
||||||
}
|
}
|
||||||
if c != nil && c.InMultiState() {
|
if c != nil && c.InMultiState() {
|
||||||
EnqueueCmd(c, cmdLine)
|
return EnqueueCmd(c, cmdLine)
|
||||||
return protocol.MakeQueuedReply()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return db.execNormalCommand(cmdLine)
|
return db.execNormalCommand(cmdLine)
|
||||||
|
@@ -51,14 +51,19 @@ func EnqueueCmd(conn redis.Connection, cmdLine [][]byte) redis.Reply {
|
|||||||
cmdName := strings.ToLower(string(cmdLine[0]))
|
cmdName := strings.ToLower(string(cmdLine[0]))
|
||||||
cmd, ok := cmdTable[cmdName]
|
cmd, ok := cmdTable[cmdName]
|
||||||
if !ok {
|
if !ok {
|
||||||
return protocol.MakeErrReply("ERR unknown command '" + cmdName + "'")
|
err := protocol.MakeErrReply("ERR unknown command '" + cmdName + "'")
|
||||||
|
conn.AddTxError(err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
if cmd.prepare == nil {
|
if cmd.prepare == nil {
|
||||||
return protocol.MakeErrReply("ERR command '" + cmdName + "' cannot be used in MULTI")
|
err := protocol.MakeErrReply("ERR command '" + cmdName + "' cannot be used in MULTI")
|
||||||
|
conn.AddTxError(err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
if !validateArity(cmd.arity, cmdLine) {
|
if !validateArity(cmd.arity, cmdLine) {
|
||||||
// difference with redis: we won't enqueue command line with wrong arity
|
err := protocol.MakeArgNumErrReply(cmdName)
|
||||||
return protocol.MakeArgNumErrReply(cmdName)
|
conn.AddTxError(err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
conn.EnqueueCmd(cmdLine)
|
conn.EnqueueCmd(cmdLine)
|
||||||
return protocol.MakeQueuedReply()
|
return protocol.MakeQueuedReply()
|
||||||
@@ -69,6 +74,9 @@ func execMulti(db *DB, conn redis.Connection) redis.Reply {
|
|||||||
return protocol.MakeErrReply("ERR EXEC without MULTI")
|
return protocol.MakeErrReply("ERR EXEC without MULTI")
|
||||||
}
|
}
|
||||||
defer conn.SetMultiState(false)
|
defer conn.SetMultiState(false)
|
||||||
|
if len(conn.GetTxErrors()) > 0 {
|
||||||
|
return protocol.MakeErrReply("EXECABORT Transaction discarded because of previous errors.")
|
||||||
|
}
|
||||||
cmdLines := conn.GetQueuedCmdLine()
|
cmdLines := conn.GetQueuedCmdLine()
|
||||||
return db.ExecMulti(conn, conn.GetWatching(), cmdLines)
|
return db.ExecMulti(conn, conn.GetWatching(), cmdLines)
|
||||||
}
|
}
|
||||||
|
@@ -31,6 +31,20 @@ func TestMulti(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSyntaxErr(t *testing.T) {
|
||||||
|
conn := new(connection.FakeConn)
|
||||||
|
testServer.Exec(conn, utils.ToCmdLine("FLUSHALL"))
|
||||||
|
result := testServer.Exec(conn, utils.ToCmdLine("multi"))
|
||||||
|
asserts.AssertNotError(t, result)
|
||||||
|
result = testServer.Exec(conn, utils.ToCmdLine("set"))
|
||||||
|
asserts.AssertErrReply(t, result, "ERR wrong number of arguments for 'set' command")
|
||||||
|
testServer.Exec(conn, utils.ToCmdLine("get", "a"))
|
||||||
|
result = testServer.Exec(conn, utils.ToCmdLine("exec"))
|
||||||
|
asserts.AssertErrReply(t, result, "EXECABORT Transaction discarded because of previous errors.")
|
||||||
|
result = testServer.Exec(conn, utils.ToCmdLine("get", "a"))
|
||||||
|
asserts.AssertNotError(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
func TestRollback(t *testing.T) {
|
func TestRollback(t *testing.T) {
|
||||||
conn := new(connection.FakeConn)
|
conn := new(connection.FakeConn)
|
||||||
testServer.Exec(conn, utils.ToCmdLine("FLUSHALL"))
|
testServer.Exec(conn, utils.ToCmdLine("FLUSHALL"))
|
||||||
|
@@ -72,6 +72,7 @@ func (dict *SimpleDict) Keys() []string {
|
|||||||
i := 0
|
i := 0
|
||||||
for k := range dict.m {
|
for k := range dict.m {
|
||||||
result[i] = k
|
result[i] = k
|
||||||
|
i++
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
@@ -2,18 +2,34 @@ package dict
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/hdt3213/godis/lib/utils"
|
"github.com/hdt3213/godis/lib/utils"
|
||||||
|
"sort"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSimpleDict_Keys(t *testing.T) {
|
func TestSimpleDict_Keys(t *testing.T) {
|
||||||
d := MakeSimple()
|
d := MakeSimple()
|
||||||
size := 10
|
size := 10
|
||||||
|
var expectKeys []string
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
d.Put(utils.RandString(5), utils.RandString(5))
|
str := utils.RandString(5)
|
||||||
|
d.Put(str, str)
|
||||||
|
expectKeys = append(expectKeys, str)
|
||||||
}
|
}
|
||||||
if len(d.Keys()) != size {
|
sort.Slice(expectKeys, func(i, j int) bool {
|
||||||
|
return expectKeys[i] > expectKeys[j]
|
||||||
|
})
|
||||||
|
keys := d.Keys()
|
||||||
|
if len(keys) != size {
|
||||||
t.Errorf("expect %d keys, actual: %d", size, len(d.Keys()))
|
t.Errorf("expect %d keys, actual: %d", size, len(d.Keys()))
|
||||||
}
|
}
|
||||||
|
sort.Slice(keys, func(i, j int) bool {
|
||||||
|
return keys[i] > keys[j]
|
||||||
|
})
|
||||||
|
for i, k := range keys {
|
||||||
|
if k != expectKeys[i] {
|
||||||
|
t.Errorf("expect %s actual %s", expectKeys[i], k)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSimpleDict_PutIfExists(t *testing.T) {
|
func TestSimpleDict_PutIfExists(t *testing.T) {
|
||||||
|
@@ -19,6 +19,8 @@ type Connection interface {
|
|||||||
EnqueueCmd([][]byte)
|
EnqueueCmd([][]byte)
|
||||||
ClearQueuedCmds()
|
ClearQueuedCmds()
|
||||||
GetWatching() map[string]uint32
|
GetWatching() map[string]uint32
|
||||||
|
AddTxError(err error)
|
||||||
|
GetTxErrors() []error
|
||||||
|
|
||||||
// used for multi database
|
// used for multi database
|
||||||
GetDBIndex() int
|
GetDBIndex() int
|
||||||
|
@@ -31,7 +31,7 @@ func (w *Wait) WaitWithTimeout(timeout time.Duration) bool {
|
|||||||
c := make(chan bool, 1)
|
c := make(chan bool, 1)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(c)
|
defer close(c)
|
||||||
w.wg.Wait()
|
w.Wait()
|
||||||
c <- true
|
c <- true
|
||||||
}()
|
}()
|
||||||
select {
|
select {
|
||||||
|
@@ -35,6 +35,7 @@ type Connection struct {
|
|||||||
multiState bool
|
multiState bool
|
||||||
queue [][][]byte
|
queue [][][]byte
|
||||||
watching map[string]uint32
|
watching map[string]uint32
|
||||||
|
txErrors []error
|
||||||
|
|
||||||
// selected db
|
// selected db
|
||||||
selectedDB int
|
selectedDB int
|
||||||
@@ -65,11 +66,9 @@ func (c *Connection) Write(b []byte) error {
|
|||||||
if len(b) == 0 {
|
if len(b) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
c.mu.Lock()
|
|
||||||
c.waitingReply.Add(1)
|
c.waitingReply.Add(1)
|
||||||
defer func() {
|
defer func() {
|
||||||
c.waitingReply.Done()
|
c.waitingReply.Done()
|
||||||
c.mu.Unlock()
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, err := c.conn.Write(b)
|
_, err := c.conn.Write(b)
|
||||||
@@ -151,6 +150,16 @@ func (c *Connection) EnqueueCmd(cmdLine [][]byte) {
|
|||||||
c.queue = append(c.queue, cmdLine)
|
c.queue = append(c.queue, cmdLine)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddTxError stores syntax error within transaction
|
||||||
|
func (c *Connection) AddTxError(err error) {
|
||||||
|
c.txErrors = append(c.txErrors, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTxErrors returns syntax error within transaction
|
||||||
|
func (c *Connection) GetTxErrors() []error {
|
||||||
|
return c.txErrors
|
||||||
|
}
|
||||||
|
|
||||||
// ClearQueuedCmds clears queued commands of current transaction
|
// ClearQueuedCmds clears queued commands of current transaction
|
||||||
func (c *Connection) ClearQueuedCmds() {
|
func (c *Connection) ClearQueuedCmds() {
|
||||||
c.queue = nil
|
c.queue = nil
|
||||||
|
Reference in New Issue
Block a user