diff --git a/src/datastruct/lock/lock_map.go b/src/datastruct/lock/lock_map.go new file mode 100644 index 0000000..3ccdd25 --- /dev/null +++ b/src/datastruct/lock/lock_map.go @@ -0,0 +1,131 @@ +package lock + +import ( + "fmt" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "time" +) + +type LockMap struct { + m sync.Map // key -> mutex +} + +func (lock *LockMap)Lock(key string) { + mu := &sync.RWMutex{} + existed, loaded := lock.m.LoadOrStore(key, mu) + if loaded { + mu, _ = existed.(*sync.RWMutex) + } + mu.Lock() +} + +func (lock *LockMap)RLock(key string) { + mu := &sync.RWMutex{} + existed, loaded := lock.m.LoadOrStore(key, mu) + if loaded { + mu, _ = existed.(*sync.RWMutex) + } + mu.RLock() +} + +func (lock *LockMap)UnLock(key string) { + value, ok := lock.m.Load(key) + if !ok { + return + } + mu := value.(*sync.RWMutex) + mu.Unlock() +} + +func (lock *LockMap)RUnLock(key string) { + value, ok := lock.m.Load(key) + if !ok { + return + } + mu := value.(*sync.RWMutex) + mu.RUnlock() +} + +func (lock *LockMap)Locks(keys ...string) { + keySlice := make(sort.StringSlice, len(keys)) + copy(keySlice, keys) + sort.Sort(keySlice) + for _, key := range keySlice { + lock.Lock(key) + } +} + +func (lock *LockMap)RLocks(keys ...string) { + keySlice := make(sort.StringSlice, len(keys)) + copy(keySlice, keys) + sort.Sort(keySlice) + for _, key := range keySlice { + lock.RLock(key) + } +} + + +func (lock *LockMap)UnLocks(keys ...string) { + size := len(keys) + keySlice := make(sort.StringSlice, size) + copy(keySlice, keys) + sort.Sort(keySlice) + for i := size - 1; i >= 0; i-- { + key := keySlice[i] + lock.UnLock(key) + } +} + +func (lock *LockMap)RUnLocks(keys ...string) { + size := len(keys) + keySlice := make(sort.StringSlice, size) + copy(keySlice, keys) + sort.Sort(keySlice) + for i := size - 1; i >= 0; i-- { + key := keySlice[i] + lock.RUnLock(key) + } +} + +func (lock *LockMap)Clean(key string) { + lock.m.Delete(key) +} + +func (lock *LockMap)Cleans(keys ...string) { + for _, key := range keys { + lock.Clean(key) + } +} + +func GoID() int { + var buf [64]byte + n := runtime.Stack(buf[:], false) + idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0] + id, err := strconv.Atoi(idField) + if err != nil { + panic(fmt.Sprintf("cannot get goroutine id: %v", err)) + } + return id +} + +func debug() { + lm := LockMap{} + size := 10 + var wg sync.WaitGroup + wg.Add(size) + for i := 0; i < size; i++ { + go func(i int) { + lm.Locks("1", "2") + println("go: " + strconv.Itoa(GoID())) + time.Sleep(time.Second) + println("go: " + strconv.Itoa(GoID())) + lm.UnLocks("1", "2") + wg.Done() + }(i) + } + wg.Wait() +} \ No newline at end of file diff --git a/src/db/db.go b/src/db/db.go index 5bbeb5c..6111d67 100644 --- a/src/db/db.go +++ b/src/db/db.go @@ -3,12 +3,12 @@ package db import ( "fmt" "github.com/HDT3213/godis/src/datastruct/dict" + "github.com/HDT3213/godis/src/datastruct/lock" "github.com/HDT3213/godis/src/interface/redis" "github.com/HDT3213/godis/src/lib/logger" "github.com/HDT3213/godis/src/redis/reply" "runtime/debug" "strings" - "sync" ) const ( @@ -23,10 +23,6 @@ type DataEntity struct { Code uint8 TTL int64 // ttl in seconds, 0 for unlimited ttl Data interface{} - - // dict will ensure thread safety (by using mutex) of its method - // use this mutex for complicated command only, eg. rpush, incr ... - sync.RWMutex } type DataEntityWithKey struct { @@ -38,7 +34,12 @@ type DataEntityWithKey struct { type CmdFunc func(db *DB, args [][]byte)redis.Reply type DB struct { - Data *dict.Dict // key -> DataEntity + // key -> DataEntity + Data *dict.Dict + + // dict will ensure thread safety of its method + // use this mutex for complicated command only, eg. rpush, incr ... + Locks *lock.LockMap } var cmdMap = MakeCmdMap() @@ -53,6 +54,7 @@ func MakeCmdMap()map[string]CmdFunc { cmdMap["psetex"] = PSetEX cmdMap["mset"] = MSet cmdMap["mget"] = MGet + cmdMap["msetnx"] = MSetNX cmdMap["get"] = Get cmdMap["del"] = Del @@ -75,6 +77,7 @@ func MakeCmdMap()map[string]CmdFunc { func MakeDB() *DB { return &DB{ Data: dict.Make(1024), + Locks: &lock.LockMap{}, } } @@ -98,3 +101,25 @@ func (db *DB)Exec(args [][]byte)(result redis.Reply) { } return } + +func (db *DB)Remove(key string) { + db.Data.Remove(key) + db.Locks.Clean(key) +} + +func (db *DB)Removes(keys ...string)(deleted int) { + db.Locks.Locks(keys...) + defer func() { + db.Locks.UnLocks(keys...) + db.Locks.Cleans(keys...) + }() + deleted = 0 + for _, key := range keys { + _, exists := db.Data.Get(key) + if exists { + db.Data.Remove(key) + deleted++ + } + } + return deleted +} diff --git a/src/db/get.go b/src/db/get.go deleted file mode 100644 index 297c459..0000000 --- a/src/db/get.go +++ /dev/null @@ -1,27 +0,0 @@ -package db - -import ( - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" -) - -func Get(db *DB, args [][]byte)redis.Reply { - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'get' command") - } - key := string(args[0]) - val, ok := db.Data.Get(key) - if !ok { - return &reply.NullBulkReply{} - } - entity, _ := val.(*DataEntity) - if entity.Code == StringCode { - bytes, ok := entity.Data.([]byte) - if !ok { - return &reply.UnknownErrReply{} - } - return reply.MakeBulkReply(bytes) - } else { - return &reply.WrongTypeErrReply{} - } -} \ No newline at end of file diff --git a/src/db/del.go b/src/db/keys.go similarity index 70% rename from src/db/del.go rename to src/db/keys.go index e90a33a..259909e 100644 --- a/src/db/del.go +++ b/src/db/keys.go @@ -14,14 +14,6 @@ func Del(db *DB, args [][]byte)redis.Reply { keys[i] = string(v) } - deleted := 0 - for _, key := range keys { - _, exists := db.Data.Get(key) - if exists { - db.Data.Remove(key) - deleted++ - } - } - + deleted := db.Removes(keys...) return reply.MakeIntReply(int64(deleted)) } diff --git a/src/db/lindex.go b/src/db/lindex.go deleted file mode 100644 index bfd431f..0000000 --- a/src/db/lindex.go +++ /dev/null @@ -1,49 +0,0 @@ -package db - -import ( - List "github.com/HDT3213/godis/src/datastruct/list" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" - "strconv" -) - -func LIndex(db *DB, args [][]byte)redis.Reply { - // parse args - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command") - } - key := string(args[0]) - index64, err := strconv.ParseInt(string(args[1]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - index := int(index64) - - rawEntity, exists := db.Data.Get(key) - var entity *DataEntity - if !exists { - return &reply.NullBulkReply{} - } else { - entity, _ = rawEntity.(*DataEntity) - } - entity.RLock() - defer entity.RUnlock() - - // check type - if entity.Code != ListCode { - return &reply.WrongTypeErrReply{} - } - - list, _ := entity.Data.(*List.LinkedList) - size := list.Len() // assert: size > 0 - if index < -1 * size { - return &reply.NullBulkReply{} - } else if index < 0 { - index = size + index - } else if index >= size { - return &reply.NullBulkReply{} - } - - val, _ := list.Get(index).([]byte) - return reply.MakeBulkReply(val) -} diff --git a/src/db/list.go b/src/db/list.go new file mode 100644 index 0000000..51fdfe1 --- /dev/null +++ b/src/db/list.go @@ -0,0 +1,476 @@ +package db + +import ( + List "github.com/HDT3213/godis/src/datastruct/list" + "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/redis/reply" + "strconv" +) + +func LIndex(db *DB, args [][]byte)redis.Reply { + // parse args + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command") + } + key := string(args[0]) + index64, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + index := int(index64) + + // get entity + rawEntity, exists := db.Data.Get(key) + var entity *DataEntity + if !exists { + return &reply.NullBulkReply{} + } else { + entity, _ = rawEntity.(*DataEntity) + } + + // check type + if entity.Code != ListCode { + return &reply.WrongTypeErrReply{} + } + + list, _ := entity.Data.(*List.LinkedList) + size := list.Len() // assert: size > 0 + if index < -1 * size { + return &reply.NullBulkReply{} + } else if index < 0 { + index = size + index + } else if index >= size { + return &reply.NullBulkReply{} + } + + val, _ := list.Get(index).([]byte) + return reply.MakeBulkReply(val) +} + +func LLen(db *DB, args [][]byte)redis.Reply { + // parse args + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'llen' command") + } + key := string(args[0]) + + rawEntity, exists := db.Data.Get(key) + var entity *DataEntity + if !exists { + return reply.MakeIntReply(0) + } else { + entity, _ = rawEntity.(*DataEntity) + } + + // check type + if entity.Code != ListCode { + return &reply.WrongTypeErrReply{} + } + + list, _ := entity.Data.(*List.LinkedList) + size := int64(list.Len()) + return reply.MakeIntReply(size) +} + +func LPop(db *DB, args [][]byte)redis.Reply { + // parse args + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command") + } + key := string(args[0]) + + // lock + db.Locks.Lock(key) + defer db.Locks.UnLock(key) + + // get data + rawEntity, exists := db.Data.Get(key) + var entity *DataEntity + if !exists { + return &reply.NullBulkReply{} + } else { + entity, _ = rawEntity.(*DataEntity) + } + + // check type + if entity.Code != ListCode { + return &reply.WrongTypeErrReply{} + } + + // remove + list, _ := entity.Data.(*List.LinkedList) + val, _ := list.Remove(0).([]byte) + if list.Len() == 0 { + db.Remove(key) + } + return reply.MakeBulkReply(val) +} + +func LPush(db *DB, args [][]byte)redis.Reply { + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'lpush' command") + } + key := string(args[0]) + values := args[1:] + + // lock + db.Locks.Lock(key) + defer db.Locks.UnLock(key) + + // get or init entity + rawEntity, exists := db.Data.Get(key) + var entity *DataEntity + if !exists { + entity = &DataEntity{ + Code: ListCode, + Data: &List.LinkedList{}, + } + } else { + entity, _ = rawEntity.(*DataEntity) + } + + // check type + if entity.Code != ListCode { + return &reply.WrongTypeErrReply{} + } + + // insert + list, _ := entity.Data.(*List.LinkedList) + for _, value := range values { + list.Insert(0, value) + } + db.Data.Put(key, entity) + + return reply.MakeIntReply(int64(list.Len())) +} + +func LPushX(db *DB, args [][]byte)redis.Reply { + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'lpush' command") + } + key := string(args[0]) + values := args[1:] + + // lock + db.Locks.Lock(key) + defer db.Locks.UnLock(key) + + // get or init entity + rawEntity, exists := db.Data.Get(key) + var entity *DataEntity + if !exists { + return reply.MakeIntReply(0) + } else { + entity, _ = rawEntity.(*DataEntity) + } + if entity.Code != ListCode { + return &reply.WrongTypeErrReply{} + } + + // insert + list, _ := entity.Data.(*List.LinkedList) + for _, value := range values { + list.Insert(0, value) + } + db.Data.Put(key, entity) + + return reply.MakeIntReply(int64(list.Len())) +} + +func LRange(db *DB, args [][]byte)redis.Reply { + // parse args + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'lrange' command") + } + key := string(args[0]) + start64, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + start := int(start64) + stop64, err := strconv.ParseInt(string(args[2]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + stop := int(stop64) + + // get data + rawEntity, exists := db.Data.Get(key) + var entity *DataEntity + if !exists { + return &reply.EmptyMultiBulkReply{} + } else { + entity, _ = rawEntity.(*DataEntity) + } + if entity.Code != ListCode { + return &reply.WrongTypeErrReply{} + } + + // compute index + list, _ := entity.Data.(*List.LinkedList) + size := list.Len() // assert: size > 0 + if start < -1 * size { + start = 0 + } else if start < 0 { + start = size + start + } else if start >= size { + return &reply.EmptyMultiBulkReply{} + } + if stop < -1 * size { + stop = 0 + } else if stop < 0 { + stop = size + stop + 1 + } else if stop < size { + stop = stop + 1 + } else { + stop = size + } + if stop < start { + stop = start + } + + // assert: start in [0, size - 1], stop in [start, size] + slice := list.Range(start, stop) + result := make([][]byte, len(slice)) + for i, raw := range slice { + bytes, _ := raw.([]byte) + result[i] = bytes + } + return reply.MakeMultiBulkReply(result) +} + +func LRem(db *DB, args [][]byte)redis.Reply { + // parse args + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'lrem' command") + } + key := string(args[0]) + count64, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + count := int(count64) + value := args[2] + + // lock + db.Locks.Lock(key) + defer db.Locks.UnLock(key) + + // get data entity + rawEntity, exists := db.Data.Get(key) + var entity *DataEntity + if !exists { + return reply.MakeIntReply(0) + } else { + entity, _ = rawEntity.(*DataEntity) + } + if entity.Code != ListCode { + return &reply.WrongTypeErrReply{} + } + + list, _ := entity.Data.(*List.LinkedList) + var removed int + if count == 0 { + removed = list.RemoveAllByVal(value) + } else if count > 0 { + removed = list.RemoveByVal(value, count) + } else { + removed = list.ReverseRemoveByVal(value, -count) + } + + if list.Len() == 0 { + db.Remove(key) + } + + return reply.MakeIntReply(int64(removed)) +} + +func LSet(db *DB, args [][]byte)redis.Reply { + // parse args + if len(args) != 3 { + return reply.MakeErrReply("ERR wrong number of arguments for 'lset' command") + } + key := string(args[0]) + index64, err := strconv.ParseInt(string(args[1]), 10, 64) + if err != nil { + return reply.MakeErrReply("ERR value is not an integer or out of range") + } + index := int(index64) + value := args[2] + + // lock + db.Locks.Lock(key) + defer db.Locks.UnLock(key) + + // get data + rawEntity, exists := db.Data.Get(key) + var entity *DataEntity + if !exists { + return reply.MakeErrReply("ERR no such key") + } else { + entity, _ = rawEntity.(*DataEntity) + } + if entity.Code != ListCode { + return &reply.WrongTypeErrReply{} + } + + list, _ := entity.Data.(*List.LinkedList) + size := list.Len() // assert: size > 0 + if index < -1 * size { + return reply.MakeErrReply("ERR index out of range") + } else if index < 0 { + index = size + index + } else if index >= size { + return reply.MakeErrReply("ERR index out of range") + } + + list.Set(index, value) + return &reply.OkReply{} +} + +func RPop(db *DB, args [][]byte)redis.Reply { + // parse args + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command") + } + key := string(args[0]) + + // lock + db.Locks.Lock(key) + defer db.Locks.UnLock(key) + + // get data + rawEntity, exists := db.Data.Get(key) + var entity *DataEntity + if !exists { + return &reply.NullBulkReply{} + } else { + entity, _ = rawEntity.(*DataEntity) + } + if entity.Code != ListCode { + return &reply.WrongTypeErrReply{} + } + + list, _ := entity.Data.(*List.LinkedList) + val, _ := list.RemoveLast().([]byte) + if list.Len() == 0 { + db.Remove(key) + } + return reply.MakeBulkReply(val) +} + +func RPopLPush(db *DB, args [][]byte)redis.Reply { + if len(args) != 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'rpoplpush' command") + } + sourceKey := string(args[0]) + destKey := string(args[1]) + + // lock + db.Locks.Locks(sourceKey, destKey) + defer db.Locks.UnLocks(sourceKey, destKey) + + // get source entity + rawEntity, exists := db.Data.Get(sourceKey) + var sourceEntity *DataEntity + if !exists { + return &reply.NullBulkReply{} + } else { + sourceEntity, _ = rawEntity.(*DataEntity) + } + sourceList, _ := sourceEntity.Data.(*List.LinkedList) + + // get dest entity + rawEntity, exists = db.Data.Get(destKey) + var destEntity *DataEntity + if !exists { + destEntity = &DataEntity{ + Code: ListCode, + Data: &List.LinkedList{}, + } + db.Data.Put(destKey, destEntity) + } else { + destEntity, _ = rawEntity.(*DataEntity) + } + destList, _ := destEntity.Data.(*List.LinkedList) + + // pop and push + val, _ := sourceList.RemoveLast().([]byte) + destList.Insert(0, val) + + if sourceList.Len() == 0 { + db.Remove(sourceKey) + } + + return reply.MakeBulkReply(val) +} + +func RPush(db *DB, args [][]byte)redis.Reply { + // parse args + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command") + } + key := string(args[0]) + values := args[1:] + + // lock + db.Locks.Lock(key) + defer db.Locks.UnLock(key) + + // get or init entity + rawEntity, exists := db.Data.Get(key) + var entity *DataEntity + if !exists { + entity = &DataEntity{ + Code: ListCode, + Data: &List.LinkedList{}, + } + } else { + entity, _ = rawEntity.(*DataEntity) + } + if entity.Code != ListCode { + return &reply.WrongTypeErrReply{} + } + + // put list + list, _ := entity.Data.(*List.LinkedList) + for _, value := range values { + list.Add(value) + } + db.Data.Put(key, entity) + + return reply.MakeIntReply(int64(list.Len())) +} + +func RPushX(db *DB, args [][]byte)redis.Reply { + if len(args) < 2 { + return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command") + } + key := string(args[0]) + values := args[1:] + + // lock + db.Locks.Lock(key) + defer db.Locks.UnLock(key) + + // get or init entity + rawEntity, exists := db.Data.Get(key) + var entity *DataEntity + if !exists { + return reply.MakeIntReply(0) + } else { + entity, _ = rawEntity.(*DataEntity) + } + if entity.Code != ListCode { + return &reply.WrongTypeErrReply{} + } + + // put list + list, _ := entity.Data.(*List.LinkedList) + for _, value := range values { + list.Add(value) + } + db.Data.Put(key, entity) + + return reply.MakeIntReply(int64(list.Len())) +} \ No newline at end of file diff --git a/src/db/llen.go b/src/db/llen.go deleted file mode 100644 index 3809286..0000000 --- a/src/db/llen.go +++ /dev/null @@ -1,34 +0,0 @@ -package db - -import ( - List "github.com/HDT3213/godis/src/datastruct/list" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" -) - -func LLen(db *DB, args [][]byte)redis.Reply { - // parse args - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'llen' command") - } - key := string(args[0]) - - rawEntity, exists := db.Data.Get(key) - var entity *DataEntity - if !exists { - return reply.MakeIntReply(0) - } else { - entity, _ = rawEntity.(*DataEntity) - } - entity.RLock() - defer entity.RUnlock() - - // check type - if entity.Code != ListCode { - return &reply.WrongTypeErrReply{} - } - - list, _ := entity.Data.(*List.LinkedList) - size := int64(list.Len()) - return reply.MakeIntReply(size) -} \ No newline at end of file diff --git a/src/db/lpop.go b/src/db/lpop.go deleted file mode 100644 index bf6d937..0000000 --- a/src/db/lpop.go +++ /dev/null @@ -1,38 +0,0 @@ -package db - -import ( - List "github.com/HDT3213/godis/src/datastruct/list" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" -) - -func LPop(db *DB, args [][]byte)redis.Reply { - // parse args - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command") - } - key := string(args[0]) - - // get data - rawEntity, exists := db.Data.Get(key) - var entity *DataEntity - if !exists { - return &reply.NullBulkReply{} - } else { - entity, _ = rawEntity.(*DataEntity) - } - entity.Lock() - defer entity.Unlock() - - // check type - if entity.Code != ListCode { - return &reply.WrongTypeErrReply{} - } - - list, _ := entity.Data.(*List.LinkedList) - val, _ := list.Remove(0).([]byte) - if list.Len() == 0 { - db.Data.Remove(key) - } - return reply.MakeBulkReply(val) -} diff --git a/src/db/lpush.go b/src/db/lpush.go deleted file mode 100644 index 8fd7ee2..0000000 --- a/src/db/lpush.go +++ /dev/null @@ -1,73 +0,0 @@ -package db - -import ( - List "github.com/HDT3213/godis/src/datastruct/list" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" -) - -func LPush(db *DB, args [][]byte)redis.Reply { - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'lpush' command") - } - key := string(args[0]) - values := args[1:] - - // get or init entity - rawEntity, exists := db.Data.Get(key) - var entity *DataEntity - if !exists { - entity = &DataEntity{ - Code: ListCode, - Data: &List.LinkedList{}, - } - } else { - entity, _ = rawEntity.(*DataEntity) - } - entity.Lock() - defer entity.Unlock() - - if entity.Code != ListCode { - return &reply.WrongTypeErrReply{} - } - - // insert - list, _ := entity.Data.(*List.LinkedList) - for _, value := range values { - list.Insert(0, value) - } - db.Data.Put(key, entity) - - return reply.MakeIntReply(int64(list.Len())) -} - -func LPushX(db *DB, args [][]byte)redis.Reply { - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'lpush' command") - } - key := string(args[0]) - values := args[1:] - - // get or init entity - rawEntity, exists := db.Data.Get(key) - var entity *DataEntity - if !exists { - return reply.MakeIntReply(0) - } else { - entity, _ = rawEntity.(*DataEntity) - } - if entity.Code != ListCode { - return &reply.WrongTypeErrReply{} - } - entity.Lock() - defer entity.Unlock() - - // insert - list, _ := entity.Data.(*List.LinkedList) - for _, value := range values { - list.Insert(0, value) - } - db.Data.Put(key, entity) - - return reply.MakeIntReply(int64(list.Len())) -} \ No newline at end of file diff --git a/src/db/lrange.go b/src/db/lrange.go deleted file mode 100644 index 56e01b1..0000000 --- a/src/db/lrange.go +++ /dev/null @@ -1,72 +0,0 @@ -package db - -import ( - List "github.com/HDT3213/godis/src/datastruct/list" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" - "strconv" -) - -func LRange(db *DB, args [][]byte)redis.Reply { - // parse args - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'lrange' command") - } - key := string(args[0]) - start64, err := strconv.ParseInt(string(args[1]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - start := int(start64) - stop64, err := strconv.ParseInt(string(args[2]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - stop := int(stop64) - - // get data - rawEntity, exists := db.Data.Get(key) - var entity *DataEntity - if !exists { - return &reply.EmptyMultiBulkReply{} - } else { - entity, _ = rawEntity.(*DataEntity) - } - if entity.Code != ListCode { - return &reply.WrongTypeErrReply{} - } - entity.RLock() - defer entity.RUnlock() - - // compute index - list, _ := entity.Data.(*List.LinkedList) - size := list.Len() // assert: size > 0 - if start < -1 * size { - start = 0 - } else if start < 0 { - start = size + start - } else if start >= size { - return &reply.EmptyMultiBulkReply{} - } - if stop < -1 * size { - stop = 0 - } else if stop < 0 { - stop = size + stop + 1 - } else if stop < size { - stop = stop + 1 - } else { - stop = size - } - if stop < start { - stop = start - } - - // assert: start in [0, size - 1], stop in [start, size] - slice := list.Range(start, stop) - result := make([][]byte, len(slice)) - for i, raw := range slice { - bytes, _ := raw.([]byte) - result[i] = bytes - } - return reply.MakeMultiBulkReply(result) -} diff --git a/src/db/lrem.go b/src/db/lrem.go deleted file mode 100644 index 42cfb54..0000000 --- a/src/db/lrem.go +++ /dev/null @@ -1,52 +0,0 @@ -package db - -import ( - List "github.com/HDT3213/godis/src/datastruct/list" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" - "strconv" -) - -func LRem(db *DB, args [][]byte)redis.Reply { - // parse args - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'lrem' command") - } - key := string(args[0]) - count64, err := strconv.ParseInt(string(args[1]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - count := int(count64) - value := args[2] - - // get data entity - rawEntity, exists := db.Data.Get(key) - var entity *DataEntity - if !exists { - return reply.MakeIntReply(0) - } else { - entity, _ = rawEntity.(*DataEntity) - } - if entity.Code != ListCode { - return &reply.WrongTypeErrReply{} - } - entity.Lock() - defer entity.Unlock() - - list, _ := entity.Data.(*List.LinkedList) - var removed int - if count == 0 { - removed = list.RemoveAllByVal(value) - } else if count > 0 { - removed = list.RemoveByVal(value, count) - } else { - removed = list.ReverseRemoveByVal(value, -count) - } - - if list.Len() == 0 { - db.Data.Remove(key) - } - - return reply.MakeIntReply(int64(removed)) -} \ No newline at end of file diff --git a/src/db/lset.go b/src/db/lset.go deleted file mode 100644 index 14623f8..0000000 --- a/src/db/lset.go +++ /dev/null @@ -1,49 +0,0 @@ -package db - -import ( - List "github.com/HDT3213/godis/src/datastruct/list" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" - "strconv" -) - -func LSet(db *DB, args [][]byte)redis.Reply { - // parse args - if len(args) != 3 { - return reply.MakeErrReply("ERR wrong number of arguments for 'lset' command") - } - key := string(args[0]) - index64, err := strconv.ParseInt(string(args[1]), 10, 64) - if err != nil { - return reply.MakeErrReply("ERR value is not an integer or out of range") - } - index := int(index64) - value := args[2] - - // get data - rawEntity, exists := db.Data.Get(key) - var entity *DataEntity - if !exists { - return reply.MakeErrReply("ERR no such key") - } else { - entity, _ = rawEntity.(*DataEntity) - } - if entity.Code != ListCode { - return &reply.WrongTypeErrReply{} - } - entity.Lock() - defer entity.Unlock() - - list, _ := entity.Data.(*List.LinkedList) - size := list.Len() // assert: size > 0 - if index < -1 * size { - return reply.MakeErrReply("ERR index out of range") - } else if index < 0 { - index = size + index - } else if index >= size { - return reply.MakeErrReply("ERR index out of range") - } - - list.Set(index, value) - return &reply.OkReply{} -} \ No newline at end of file diff --git a/src/db/mget.go b/src/db/mget.go deleted file mode 100644 index 572ec5e..0000000 --- a/src/db/mget.go +++ /dev/null @@ -1,34 +0,0 @@ -package db - -import ( - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" -) - -func MGet(db *DB, args [][]byte)redis.Reply { - if len(args) == 0 { - return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command") - } - keys := make([]string, len(args)) - for i, v := range args { - keys[i] = string(v) - } - - result := make([][]byte, len(args)) - for i, key := range keys { - val, exists := db.Data.Get(key) - if !exists { - result[i] = nil - continue - } - entity, _ := val.(*DataEntity) - if entity.Code != StringCode { - result[i] = nil - continue - } - bytes, _ := entity.Data.([]byte) - result[i] = bytes - } - - return reply.MakeMultiBulkReply(result) -} \ No newline at end of file diff --git a/src/db/mset.go b/src/db/mset.go deleted file mode 100644 index 01e0253..0000000 --- a/src/db/mset.go +++ /dev/null @@ -1,33 +0,0 @@ -package db - -import ( - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" -) - - -func MSet(db *DB, args [][]byte)redis.Reply { - if len(args) % 2 != 0 || len(args) == 0 { - return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command") - } - size := len(args) / 2 - entities := make([]*DataEntityWithKey, size) - for i := 0; i < size; i++ { - key := string(args[2 * i]) - value := args[2 * i + 1] - entity := &DataEntityWithKey{ - DataEntity: DataEntity{ - Code: StringCode, - Data: value, - }, - Key: key, - } - entities[i] = entity - } - - for _, entity := range entities { - db.Data.Put(entity.Key, &entity.DataEntity) - } - - return &reply.OkReply{} -} diff --git a/src/db/rpop.go b/src/db/rpop.go deleted file mode 100644 index c3355e5..0000000 --- a/src/db/rpop.go +++ /dev/null @@ -1,36 +0,0 @@ -package db - -import ( - List "github.com/HDT3213/godis/src/datastruct/list" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" -) - -func RPop(db *DB, args [][]byte)redis.Reply { - // parse args - if len(args) != 1 { - return reply.MakeErrReply("ERR wrong number of arguments for 'lindex' command") - } - key := string(args[0]) - - // get data - rawEntity, exists := db.Data.Get(key) - var entity *DataEntity - if !exists { - return &reply.NullBulkReply{} - } else { - entity, _ = rawEntity.(*DataEntity) - } - if entity.Code != ListCode { - return &reply.WrongTypeErrReply{} - } - entity.Lock() - defer entity.Unlock() - - list, _ := entity.Data.(*List.LinkedList) - val, _ := list.RemoveLast().([]byte) - if list.Len() == 0 { - db.Data.Remove(key) - } - return reply.MakeBulkReply(val) -} diff --git a/src/db/rpoplpush.go b/src/db/rpoplpush.go deleted file mode 100644 index a71473a..0000000 --- a/src/db/rpoplpush.go +++ /dev/null @@ -1,49 +0,0 @@ -package db - -import ( - List "github.com/HDT3213/godis/src/datastruct/list" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" -) - -func RPopLPush(db *DB, args [][]byte)redis.Reply { - if len(args) != 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'rpoplpush' command") - } - sourceKey := string(args[0]) - destKey := string(args[1]) - - // get source entity - rawEntity, exists := db.Data.Get(sourceKey) - var sourceEntity *DataEntity - if !exists { - return &reply.NullBulkReply{} - } else { - sourceEntity, _ = rawEntity.(*DataEntity) - } - sourceList, _ := sourceEntity.Data.(*List.LinkedList) - sourceEntity.Lock() - defer sourceEntity.Unlock() - - // get dest entity - rawEntity, exists = db.Data.Get(destKey) - var destEntity *DataEntity - if !exists { - destEntity = &DataEntity{ - Code: ListCode, - Data: &List.LinkedList{}, - } - db.Data.Put(destKey, destEntity) - } else { - destEntity, _ = rawEntity.(*DataEntity) - } - destList, _ := destEntity.Data.(*List.LinkedList) - destEntity.Lock() - defer destEntity.Unlock() - - // pop and push - val, _ := sourceList.RemoveLast().([]byte) - destList.Insert(0, val) - - return reply.MakeBulkReply(val) -} diff --git a/src/db/rpush.go b/src/db/rpush.go deleted file mode 100644 index 5283a50..0000000 --- a/src/db/rpush.go +++ /dev/null @@ -1,72 +0,0 @@ -package db - -import ( - List "github.com/HDT3213/godis/src/datastruct/list" - "github.com/HDT3213/godis/src/interface/redis" - "github.com/HDT3213/godis/src/redis/reply" -) - -func RPush(db *DB, args [][]byte)redis.Reply { - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command") - } - key := string(args[0]) - values := args[1:] - - // get or init entity - rawEntity, exists := db.Data.Get(key) - var entity *DataEntity - if !exists { - entity = &DataEntity{ - Code: ListCode, - Data: &List.LinkedList{}, - } - } else { - entity, _ = rawEntity.(*DataEntity) - } - if entity.Code != ListCode { - return &reply.WrongTypeErrReply{} - } - entity.Lock() - defer entity.Unlock() - - // put list - list, _ := entity.Data.(*List.LinkedList) - for _, value := range values { - list.Add(value) - } - db.Data.Put(key, entity) - - return reply.MakeIntReply(int64(list.Len())) -} - -func RPushX(db *DB, args [][]byte)redis.Reply { - if len(args) < 2 { - return reply.MakeErrReply("ERR wrong number of arguments for 'rpush' command") - } - key := string(args[0]) - values := args[1:] - - // get or init entity - rawEntity, exists := db.Data.Get(key) - var entity *DataEntity - if !exists { - return reply.MakeIntReply(0) - } else { - entity, _ = rawEntity.(*DataEntity) - } - if entity.Code != ListCode { - return &reply.WrongTypeErrReply{} - } - entity.Lock() - defer entity.Unlock() - - // put list - list, _ := entity.Data.(*List.LinkedList) - for _, value := range values { - list.Add(value) - } - db.Data.Put(key, entity) - - return reply.MakeIntReply(int64(list.Len())) -} \ No newline at end of file diff --git a/src/db/ping.go b/src/db/server.go similarity index 99% rename from src/db/ping.go rename to src/db/server.go index 91809c0..8c6c5fe 100644 --- a/src/db/ping.go +++ b/src/db/server.go @@ -1,8 +1,8 @@ package db import ( - "github.com/HDT3213/godis/src/redis/reply" "github.com/HDT3213/godis/src/interface/redis" + "github.com/HDT3213/godis/src/redis/reply" ) func Ping(db *DB, args [][]byte)redis.Reply { @@ -13,4 +13,4 @@ func Ping(db *DB, args [][]byte)redis.Reply { } else { return reply.MakeErrReply("ERR wrong number of arguments for 'ping' command") } -} +} \ No newline at end of file diff --git a/src/db/set.go b/src/db/string.go similarity index 59% rename from src/db/set.go rename to src/db/string.go index 2a4db91..4495886 100644 --- a/src/db/set.go +++ b/src/db/string.go @@ -7,6 +7,27 @@ import ( "strings" ) +func Get(db *DB, args [][]byte)redis.Reply { + if len(args) != 1 { + return reply.MakeErrReply("ERR wrong number of arguments for 'get' command") + } + key := string(args[0]) + val, ok := db.Data.Get(key) + if !ok { + return &reply.NullBulkReply{} + } + entity, _ := val.(*DataEntity) + if entity.Code == StringCode { + bytes, ok := entity.Data.([]byte) + if !ok { + return &reply.UnknownErrReply{} + } + return reply.MakeBulkReply(bytes) + } else { + return &reply.WrongTypeErrReply{} + } +} + const ( upsertPolicy = iota // default insertPolicy // set nx @@ -156,4 +177,98 @@ func PSetEX(db *DB, args [][]byte)redis.Reply { } db.Data.PutIfExists(key, entity) return &reply.OkReply{} +} + +func MSet(db *DB, args [][]byte)redis.Reply { + if len(args) % 2 != 0 || len(args) == 0 { + return reply.MakeErrReply("ERR wrong number of arguments for 'mset' command") + } + size := len(args) / 2 + entities := make([]*DataEntityWithKey, size) + for i := 0; i < size; i++ { + key := string(args[2 * i]) + value := args[2 * i + 1] + entity := &DataEntityWithKey{ + DataEntity: DataEntity{ + Code: StringCode, + Data: value, + }, + Key: key, + } + entities[i] = entity + } + + for _, entity := range entities { + db.Data.Put(entity.Key, &entity.DataEntity) + } + + return &reply.OkReply{} +} + +func MGet(db *DB, args [][]byte)redis.Reply { + if len(args) == 0 { + return reply.MakeErrReply("ERR wrong number of arguments for 'mget' command") + } + keys := make([]string, len(args)) + for i, v := range args { + keys[i] = string(v) + } + + result := make([][]byte, len(args)) + for i, key := range keys { + val, exists := db.Data.Get(key) + if !exists { + result[i] = nil + continue + } + entity, _ := val.(*DataEntity) + if entity.Code != StringCode { + result[i] = nil + continue + } + bytes, _ := entity.Data.([]byte) + result[i] = bytes + } + + return reply.MakeMultiBulkReply(result) +} + +func MSetNX(db *DB, args [][]byte)redis.Reply { + // parse args + if len(args) % 2 != 0 || len(args) == 0 { + return reply.MakeErrReply("ERR wrong number of arguments for 'msetnx' command") + } + size := len(args) / 2 + entities := make([]*DataEntityWithKey, size) + keys := make([]string, size) + for i := 0; i < size; i++ { + key := string(args[2 * i]) + value := args[2 * i + 1] + entity := &DataEntityWithKey{ + DataEntity: DataEntity{ + Code: StringCode, + Data: value, + }, + Key: key, + } + entities[i] = entity + keys[i] = key + } + + // lock keys + db.Locks.Locks(keys...) + defer db.Locks.UnLocks(keys...) + + for _, key := range keys { + _, exists := db.Data.Get(key) + if exists { + return reply.MakeIntReply(0) + } + } + + for _, entity := range entities { + db.Data.Put(entity.Key, &entity.DataEntity) + } + + return reply.MakeIntReply(1) } \ No newline at end of file