optimize ttl by time wheel

This commit is contained in:
hdt3213
2021-04-03 21:53:23 +08:00
parent f1277a6e2d
commit 85ef804d48
6 changed files with 207 additions and 37 deletions

View File

@@ -32,8 +32,8 @@ const (
CreatedStatus = 0
PreparedStatus = 1
CommitedStatus = 2
RollbackedStatus = 3
CommittedStatus = 2
RolledBackStatus = 3
)
func NewTransaction(cluster *Cluster, c redis.Connection, id string, args [][]byte, keys []string) *Transaction {
@@ -89,10 +89,10 @@ func (tx *Transaction) rollback() error {
tx.cluster.db.Remove(key)
}
}
if tx.status != CommitedStatus {
if tx.status != CommittedStatus {
tx.cluster.db.UnLocks(tx.keys...)
}
tx.status = RollbackedStatus
tx.status = RolledBackStatus
return nil
}
@@ -129,7 +129,7 @@ func Commit(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
// finish transaction
defer func() {
cluster.db.UnLocks(tx.keys...)
tx.status = CommitedStatus
tx.status = CommittedStatus
//cluster.transactions.Remove(tx.id) // cannot remove, may rollback after commit
}()
@@ -144,7 +144,7 @@ func Commit(cluster *Cluster, c redis.Connection, args [][]byte) redis.Reply {
if reply.IsErrorReply(result) {
// failed
err2 := tx.rollback()
return reply.MakeErrReply(fmt.Sprintf("err occurs when rollback: %v, origin err: %s", err2, result))
return reply.MakeErrReply(fmt.Sprintf("err occurs when rollback: %v, origin err: %s", err2, result))
}
return result

View File

@@ -4,10 +4,10 @@ import (
"fmt"
"github.com/HDT3213/godis/src/config"
"github.com/HDT3213/godis/src/datastruct/dict"
List "github.com/HDT3213/godis/src/datastruct/list"
"github.com/HDT3213/godis/src/datastruct/lock"
"github.com/HDT3213/godis/src/interface/redis"
"github.com/HDT3213/godis/src/lib/logger"
"github.com/HDT3213/godis/src/lib/timewheel"
"github.com/HDT3213/godis/src/pubsub"
"github.com/HDT3213/godis/src/redis/reply"
"os"
@@ -239,14 +239,26 @@ func (db *DB) RUnLocks(keys ...string) {
/* ---- TTL Functions ---- */
func genExpireTask(key string) string {
return "expire:" + key
}
func (db *DB) Expire(key string, expireTime time.Time) {
db.stopWorld.Wait()
db.TTLMap.Put(key, expireTime)
taskKey := genExpireTask(key)
timewheel.At(expireTime, taskKey, func() {
logger.Info("expire " + key)
db.TTLMap.Remove(key)
db.Data.Remove(key)
})
}
func (db *DB) Persist(key string) {
db.stopWorld.Wait()
db.TTLMap.Remove(key)
taskKey := genExpireTask(key)
timewheel.Cancel(taskKey)
}
func (db *DB) IsExpired(key string) bool {
@@ -262,32 +274,7 @@ func (db *DB) IsExpired(key string) bool {
return expired
}
func (db *DB) CleanExpired() {
now := time.Now()
toRemove := &List.LinkedList{}
db.TTLMap.ForEach(func(key string, val interface{}) bool {
expireTime, _ := val.(time.Time)
if now.After(expireTime) {
// expired
db.Data.Remove(key)
toRemove.Add(key)
}
return true
})
toRemove.ForEach(func(i int, val interface{}) bool {
key, _ := val.(string)
db.TTLMap.Remove(key)
return true
})
}
func (db *DB) TimerTask() {
ticker := time.NewTicker(db.interval)
go func() {
for range ticker.C {
db.CleanExpired()
}
}()
}
/* ---- Subscribe Functions ---- */

View File

@@ -289,7 +289,7 @@ func Persist(db *DB, args [][]byte) redis.Reply {
return reply.MakeIntReply(0)
}
db.TTLMap.Remove(key)
db.Persist(key)
db.AddAof(makeAofCmd("persist", args))
return reply.MakeIntReply(1)
}

View File

@@ -87,10 +87,7 @@ func TestRenameNx(t *testing.T) {
newKey := key + RandString(2)
Set(testDB, toArgs(key, value, "ex", "1000"))
result := RenameNx(testDB, toArgs(key, newKey))
if _, ok := result.(*reply.OkReply); !ok {
t.Error("expect ok")
return
}
asserts.AssertIntReply(t, result, 1)
result = Exists(testDB, toArgs(key))
asserts.AssertIntReply(t, result, 0)
result = Exists(testDB, toArgs(newKey))

View File

@@ -0,0 +1,21 @@
package timewheel
import "time"
var tw = New(time.Second, 3600)
func init() {
tw.Start()
}
func Delay(duration time.Duration, key string, job func()) {
tw.AddTimer(duration, key, job)
}
func At(at time.Time, key string, job func()) {
tw.AddTimer(at.Sub(time.Now()), key, job)
}
func Cancel(key string) {
tw.RemoveTimer(key)
}

View File

@@ -0,0 +1,165 @@
package timewheel
import (
"container/list"
"github.com/HDT3213/godis/src/lib/logger"
"time"
)
type TimeWheel struct {
interval time.Duration
ticker *time.Ticker
slots []*list.List
timer map[string]int
currentPos int
slotNum int
addTaskChannel chan Task
removeTaskChannel chan string
stopChannel chan bool
}
type Task struct {
delay time.Duration
circle int
key string
job func()
}
func New(interval time.Duration, slotNum int) *TimeWheel {
if interval <= 0 || slotNum <= 0 {
return nil
}
tw := &TimeWheel{
interval: interval,
slots: make([]*list.List, slotNum),
timer: make(map[string]int),
currentPos: 0,
slotNum: slotNum,
addTaskChannel: make(chan Task),
removeTaskChannel: make(chan string),
stopChannel: make(chan bool),
}
tw.initSlots()
return tw
}
func (tw *TimeWheel) initSlots() {
for i := 0; i < tw.slotNum; i++ {
tw.slots[i] = list.New()
}
}
func (tw *TimeWheel) Start() {
tw.ticker = time.NewTicker(tw.interval)
go tw.start()
}
func (tw *TimeWheel) Stop() {
tw.stopChannel <- true
}
func (tw *TimeWheel) AddTimer(delay time.Duration, key string, job func()) {
if delay < 0 {
return
}
tw.addTaskChannel <- Task{delay: delay, key: key, job: job}
}
func (tw *TimeWheel) RemoveTimer(key string) {
if key == "" {
return
}
tw.removeTaskChannel <- key
}
func (tw *TimeWheel) start() {
for {
select {
case <-tw.ticker.C:
tw.tickHandler()
case task := <-tw.addTaskChannel:
tw.addTask(&task)
case key := <-tw.removeTaskChannel:
tw.removeTask(key)
case <-tw.stopChannel:
tw.ticker.Stop()
return
}
}
}
func (tw *TimeWheel) tickHandler() {
l := tw.slots[tw.currentPos]
tw.scanAndRunTask(l)
if tw.currentPos == tw.slotNum-1 {
tw.currentPos = 0
} else {
tw.currentPos++
}
}
func (tw *TimeWheel) scanAndRunTask(l *list.List) {
for e := l.Front(); e != nil; {
task := e.Value.(*Task)
if task.circle > 0 {
task.circle--
e = e.Next()
continue
}
go func() {
defer func() {
if err := recover(); err != nil {
logger.Error(err)
}
}()
job := task.job
job()
}()
next := e.Next()
l.Remove(e)
if task.key != "" {
delete(tw.timer, task.key)
}
e = next
}
}
func (tw *TimeWheel) addTask(task *Task) {
pos, circle := tw.getPositionAndCircle(task.delay)
task.circle = circle
tw.slots[pos].PushBack(task)
if task.key != "" {
tw.timer[task.key] = pos
}
}
func (tw *TimeWheel) getPositionAndCircle(d time.Duration) (pos int, circle int) {
delaySeconds := int(d.Seconds())
intervalSeconds := int(tw.interval.Seconds())
circle = int(delaySeconds / intervalSeconds / tw.slotNum)
pos = int(tw.currentPos+delaySeconds/intervalSeconds) % tw.slotNum
return
}
func (tw *TimeWheel) removeTask(key string) {
position, ok := tw.timer[key]
if !ok {
return
}
l := tw.slots[position]
for e := l.Front(); e != nil; {
task := e.Value.(*Task)
if task.key == key {
delete(tw.timer, task.key)
l.Remove(e)
}
e = e.Next()
}
}