refactor concurrent consume

This commit is contained in:
finley
2024-09-27 21:37:48 +08:00
parent 0812e5ac21
commit a752279fa9
4 changed files with 192 additions and 93 deletions

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"log"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
@@ -14,27 +14,31 @@ import (
// DelayQueue is a message queue supporting delayed/scheduled delivery based on redis
type DelayQueue struct {
// name for this Queue. Make sure the name is unique in the redis database
name string
redisCli RedisCli
cb func(string) bool
pendingKey string // sorted set: message id -> delivery time
readyKey string // list
unAckKey string // sorted set: message id -> retry time
retryKey string // list
retryCountKey string // hash: message id -> remain retry count
garbageKey string // set: message id
useHashTag bool
ticker *time.Ticker
logger *log.Logger
close chan struct{}
name string
redisCli RedisCli
cb func(string) bool
pendingKey string // sorted set: message id -> delivery time
readyKey string // list
unAckKey string // sorted set: message id -> retry time
retryKey string // list
retryCountKey string // hash: message id -> remain retry count
garbageKey string // set: message id
useHashTag bool
ticker *time.Ticker
logger *log.Logger
close chan struct{}
running int32
maxConsumeDuration time.Duration // default 5 seconds
msgTTL time.Duration // default 1 hour
defaultRetryCount uint // default 3
fetchInterval time.Duration // default 1 second
fetchLimit uint // default no limit
fetchCount int32 // actually running task number
concurrent uint // default 1, executed serially
// for batch consume
consumeBuffer chan string
eventListener EventListener
}
@@ -163,8 +167,9 @@ func (q *DelayQueue) WithFetchLimit(limit uint) *DelayQueue {
// WithConcurrent sets the number of concurrent consumers
func (q *DelayQueue) WithConcurrent(c uint) *DelayQueue {
if c == 0 {
return q
panic("concurrent cannot be 0")
}
q.assertNotRunning()
q.concurrent = c
return q
}
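A minimal wiring sketch, in the style of this repository's tests, showing where WithConcurrent sits among the other options touched by this commit. The go-redis import path and the newExampleQueue helper are assumptions for illustration; NewQueue, UseHashTagKey and the With* options are the ones used in the test files below.

package delayqueue

import (
	"log"
	"os"
	"time"

	"github.com/redis/go-redis/v9" // assumed import path; match the module's go.mod
)

// newExampleQueue is a hypothetical helper, not part of the library.
func newExampleQueue() *DelayQueue {
	redisCli := redis.NewClient(&redis.Options{Addr: "127.0.0.1:6379"})
	cb := func(payload string) bool {
		return true // ack; return false to trigger a retry
	}
	return NewQueue("example", redisCli, UseHashTagKey()).
		WithCallback(cb).
		WithFetchInterval(time.Millisecond * 50).
		WithFetchLimit(10). // cap on fetched-but-unfinished messages
		WithConcurrent(4).  // number of worker goroutines; 0 now panics
		WithLogger(log.New(os.Stderr, "[DelayQueue]", log.LstdFlags))
}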
@@ -339,44 +344,8 @@ func (q *DelayQueue) callback(idStr string) error {
return err
}
// batchCallback calls DelayQueue.callback in batch. Callbacks are executed concurrently according to the DelayQueue.concurrent property
// batchCallback must wait until all callbacks have finished, otherwise the actual number of messages being processed may exceed DelayQueue.FetchLimit
func (q *DelayQueue) batchCallback(ids []string) {
if len(ids) == 1 || q.concurrent == 1 {
for _, id := range ids {
err := q.callback(id)
if err != nil {
q.logger.Printf("consume msg %s failed: %v", id, err)
}
}
return
}
ch := make(chan string, len(ids))
for _, id := range ids {
ch <- id
}
close(ch)
wg := sync.WaitGroup{}
concurrent := int(q.concurrent)
if concurrent > len(ids) { // no point in more goroutines than messages
concurrent = len(ids)
}
wg.Add(concurrent)
for i := 0; i < concurrent; i++ {
go func() {
defer wg.Done()
for id := range ch {
err := q.callback(id)
if err != nil {
q.logger.Printf("consume msg %s failed: %v", id, err)
}
}
}()
}
wg.Wait()
}
func (q *DelayQueue) ack(idStr string) error {
atomic.AddInt32(&q.fetchCount, -1)
err := q.redisCli.ZRem(q.unAckKey, []string{idStr})
if err != nil {
return fmt.Errorf("remove from unack failed: %v", err)
@@ -389,6 +358,7 @@ func (q *DelayQueue) ack(idStr string) error {
}
func (q *DelayQueue) nack(idStr string) error {
atomic.AddInt32(&q.fetchCount, -1)
// update retry time to now; unack2Retry will move it to retry immediately
err := q.redisCli.ZAdd(q.unAckKey, map[string]float64{
idStr: float64(time.Now().Unix()),
@@ -501,32 +471,55 @@ func (q *DelayQueue) garbageCollect() error {
return nil
}
func (q *DelayQueue) consume() error {
func (q *DelayQueue) beforeConsume() ([]string, error) {
// pending to ready
err := q.pending2Ready()
if err != nil {
return err
return nil, err
}
// consume
// ready2Unack
// prioritize new message consumption to avoid avalanches
ids := make([]string, 0, q.fetchLimit)
var fetchCount int32
for {
fetchCount = atomic.LoadInt32(&q.fetchCount)
if q.fetchLimit > 0 && fetchCount >= int32(q.fetchLimit) {
break
}
idStr, err := q.ready2Unack()
if err == NilErr { // consumed all
break
}
if err != nil {
return err
return nil, err
}
ids = append(ids, idStr)
if q.fetchLimit > 0 && len(ids) >= int(q.fetchLimit) {
break
atomic.AddInt32(&q.fetchCount, 1)
}
// retry2Unack
if fetchCount < int32(q.fetchLimit) || q.fetchLimit == 0 {
for {
fetchCount = atomic.LoadInt32(&q.fetchCount)
if q.fetchLimit > 0 && fetchCount >= int32(q.fetchLimit) {
break
}
idStr, err := q.retry2Unack()
if err == NilErr { // consumed all
break
}
if err != nil {
return nil, err
}
ids = append(ids, idStr)
atomic.AddInt32(&q.fetchCount, 1)
}
}
if len(ids) > 0 {
q.batchCallback(ids)
}
return ids, nil
}
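The fetchCount field added above gives the queue an in-flight cap: beforeConsume increments it for every message moved to unack and stops fetching once it reaches fetchLimit, while ack/nack (shown earlier) decrement it. A standalone sketch of that accounting pattern, illustrative only and not the library's code:

package main

import (
	"fmt"
	"sync/atomic"
)

// limiter mimics the fetchCount/fetchLimit bookkeeping: tryFetch reserves a
// slot (as beforeConsume does per message) and done releases it (as ack/nack do).
type limiter struct {
	inFlight int32
	limit    int32
}

func (l *limiter) tryFetch() bool {
	if l.limit > 0 && atomic.LoadInt32(&l.inFlight) >= l.limit {
		return false // limit reached: stop fetching for now
	}
	atomic.AddInt32(&l.inFlight, 1)
	return true
}

func (l *limiter) done() {
	atomic.AddInt32(&l.inFlight, -1)
}

func main() {
	l := &limiter{limit: 2}
	fmt.Println(l.tryFetch(), l.tryFetch(), l.tryFetch()) // true true false
	l.done()                                              // one message acked
	fmt.Println(l.tryFetch())                             // true again
}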
func (q *DelayQueue) afterConsume() error {
// unack to retry
err = q.unack2Retry()
err := q.unack2Retry()
if err != nil {
return err
}
@@ -534,27 +527,35 @@ func (q *DelayQueue) consume() error {
if err != nil {
return err
}
// retry
ids = make([]string, 0, q.fetchLimit)
for {
idStr, err := q.retry2Unack()
if err == NilErr { // consumed all
break
}
if err != nil {
return err
}
ids = append(ids, idStr)
if q.fetchLimit > 0 && len(ids) >= int(q.fetchLimit) {
break
}
}
if len(ids) > 0 {
q.batchCallback(ids)
}
return nil
}
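Taken together, the old consume() is now split into three phases: beforeConsume fetches ready and retry messages under the fetch limit, callback runs per message, and afterConsume moves unacked or expired messages back to retry. A test-style sketch of driving one cycle by hand, mirroring how the updated tests below call these phases (driveOneCycle itself is illustrative and not part of the library):

package delayqueue

import "log"

// driveOneCycle runs one fetch / callback / requeue pass, the way the updated
// tests exercise the refactored pipeline.
func driveOneCycle(queue *DelayQueue) {
	ids, err := queue.beforeConsume() // pending -> ready -> unack, capped by fetchLimit
	if err != nil {
		log.Printf("fetch error: %v", err)
		return
	}
	for _, id := range ids {
		if err := queue.callback(id); err != nil {
			log.Printf("consume msg %s failed: %v", id, err)
		}
	}
	if err := queue.afterConsume(); err != nil { // unack -> retry, then cleanup
		log.Printf("requeue error: %v", err)
	}
}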
func (q *DelayQueue) setRunning() {
atomic.StoreInt32(&q.running, 1)
}
func (q *DelayQueue) setNotRunning() {
atomic.StoreInt32(&q.running, 0)
}
func (q *DelayQueue) assertNotRunning() {
running := atomic.LoadInt32(&q.running)
if running > 0 {
panic("operation cannot be performed during running")
}
}
func (q *DelayQueue) goWithRecover(fn func()) {
go func() {
defer func() {
if err := recover(); err != nil {
q.logger.Printf("panic: %v\n", err)
}
}()
fn()
}()
}
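goWithRecover exists so that a panicking callback only aborts its own worker iteration instead of crashing the process. A self-contained sketch of the same recover-in-goroutine pattern (standalone names, not the queue's own types):

package main

import (
	"log"
	"time"
)

// goWithRecover runs fn in a goroutine and turns a panic into a log line.
func goWithRecover(fn func()) {
	go func() {
		defer func() {
			if err := recover(); err != nil {
				log.Printf("panic: %v", err)
			}
		}()
		fn()
	}()
}

func main() {
	goWithRecover(func() { panic("callback blew up") })
	time.Sleep(100 * time.Millisecond) // give the goroutine time to run and log
	log.Println("process is still alive")
}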
// StartConsume creates a goroutine to consume messages from DelayQueue
// use `<-done` to wait for the consumer to stop
// If there is no callback set, StartConsume will panic
@@ -563,17 +564,34 @@ func (q *DelayQueue) StartConsume() (done <-chan struct{}) {
panic("this instance has no callback")
}
q.close = make(chan struct{}, 1)
q.setRunning()
q.ticker = time.NewTicker(q.fetchInterval)
q.consumeBuffer = make(chan string, q.fetchLimit)
done0 := make(chan struct{})
// start worker
for i := 0; i < int(q.concurrent); i++ {
q.goWithRecover(func() {
for id := range q.consumeBuffer {
q.callback(id)
q.afterConsume()
}
})
}
// start main loop
go func() {
tickerLoop:
for {
select {
case <-q.ticker.C:
err := q.consume()
ids, err := q.beforeConsume()
if err != nil {
log.Printf("consume error: %v", err)
}
q.goWithRecover(func() {
for _, id := range ids {
q.consumeBuffer <- id
}
})
case <-q.close:
break tickerLoop
}
@@ -586,9 +604,11 @@ func (q *DelayQueue) StartConsume() (done <-chan struct{}) {
// StopConsume stops consumer goroutine
func (q *DelayQueue) StopConsume() {
close(q.close)
q.setNotRunning()
if q.ticker != nil {
q.ticker.Stop()
}
close(q.consumeBuffer)
}
// GetPendingCount returns the number of pending messages
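With this commit, StartConsume starts q.concurrent long-lived workers that drain consumeBuffer, while the ticker loop only fetches ids and feeds the buffer. A hedged end-to-end sketch using the exported API seen in this diff and its tests; the go-redis import path is an assumption, and runExample is a hypothetical helper:

package delayqueue

import (
	"log"
	"time"

	"github.com/redis/go-redis/v9" // assumed import path; match the module's go.mod
)

func runExample() {
	redisCli := redis.NewClient(&redis.Options{Addr: "127.0.0.1:6379"})
	queue := NewQueue("example", redisCli, UseHashTagKey()).
		WithCallback(func(payload string) bool {
			log.Printf("got %s", payload)
			return true // ack
		}).
		WithConcurrent(4). // four workers read from consumeBuffer
		WithFetchLimit(16) // at most 16 unfinished messages in flight

	done := queue.StartConsume()
	if err := queue.SendDelayMsg("hello", time.Second); err != nil {
		log.Printf("send failed: %v", err)
	}

	time.Sleep(2 * time.Second) // let the message become ready and be consumed
	queue.StopConsume()         // stops the ticker loop and closes the worker buffer
	<-done                      // wait for the consumer to stop, as the doc comment suggests
}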

View File

@@ -39,11 +39,15 @@ func TestDelayQueue_consume(t *testing.T) {
}
}
for i := 0; i < 10*size; i++ {
err := queue.consume()
ids, err := queue.beforeConsume()
if err != nil {
t.Errorf("consume error: %v", err)
return
}
for _, id := range ids {
queue.callback(id)
}
queue.afterConsume()
}
for k, v := range deliveryCount {
i, _ := strconv.ParseInt(k, 10, 64)
@@ -88,11 +92,15 @@ func TestDelayQueueOnCluster(t *testing.T) {
}
}
for i := 0; i < 10*size; i++ {
err := queue.consume()
ids, err := queue.beforeConsume()
if err != nil {
t.Errorf("consume error: %v", err)
return
}
for _, id := range ids {
queue.callback(id)
}
queue.afterConsume()
}
if succeed != size {
t.Error("msg not consumed")
@@ -127,11 +135,15 @@ func TestDelayQueue_ConcurrentConsume(t *testing.T) {
}
}
for i := 0; i < 2*size; i++ {
err := queue.consume()
ids, err := queue.beforeConsume()
if err != nil {
t.Errorf("consume error: %v", err)
return
}
for _, id := range ids {
queue.callback(id)
}
queue.afterConsume()
}
for k, v := range deliveryCount {
if v != 1 {
@@ -266,3 +278,66 @@ func TestDelayQueue_Massive_Backlog(t *testing.T) {
return
}
}
// consume should stop after the actual fetch count hits the fetch limit
func TestDelayQueue_FetchLimit(t *testing.T) {
redisCli := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:6379",
})
redisCli.FlushDB(context.Background())
fetchLimit := 10
cb := func(s string) bool {
return true
}
queue := NewQueue("test", redisCli, UseHashTagKey()).
WithCallback(cb).
WithFetchInterval(time.Millisecond * 50).
WithMaxConsumeDuration(0).
WithLogger(log.New(os.Stderr, "[DelayQueue]", log.LstdFlags)).
WithFetchLimit(uint(fetchLimit))
for i := 0; i < fetchLimit; i++ {
err := queue.SendDelayMsg(strconv.Itoa(i), 0, WithMsgTTL(time.Hour))
if err != nil {
t.Error(err)
}
}
// fetch but not consume
ids1, err := queue.beforeConsume()
if err != nil {
t.Errorf("consume error: %v", err)
return
}
// send new messages
for i := 0; i < fetchLimit; i++ {
err := queue.SendDelayMsg(strconv.Itoa(i), 0, WithMsgTTL(time.Hour))
if err != nil {
t.Error(err)
}
}
ids2, err := queue.beforeConsume()
if err != nil {
t.Errorf("consume error: %v", err)
return
}
if len(ids2) > 0 {
t.Error("should get 0 message, after hitting fetch limit")
}
// consume
for _, id := range ids1 {
queue.callback(id)
}
queue.afterConsume()
// resume
ids3, err := queue.beforeConsume()
if err != nil {
t.Errorf("consume error: %v", err)
return
}
if len(ids3) == 0 {
t.Error("should get some messages, after consumption")
}
}

View File

@@ -59,7 +59,7 @@ func TestMonitor_get_status(t *testing.T) {
// test processing count
for i := 0; i < size/2; i++ {
_ , _ = queue.ready2Unack()
_, _ = queue.ready2Unack()
}
processing, err := monitor.GetProcessingCount()
if err != nil {
@@ -109,14 +109,14 @@ func TestMonitor_listener1(t *testing.T) {
monitor := NewMonitor("test", redisCli)
profile := &MyProfiler{}
monitor.ListenEvent(profile)
for i := 0; i < size; i++ {
err := queue.SendDelayMsg(strconv.Itoa(i), 0)
if err != nil {
t.Error(err)
}
}
queue.consume()
queue.beforeConsume()
if profile.ProduceCount != size {
t.Error("wrong produce count")
@@ -143,7 +143,7 @@ func TestMonitor_listener2(t *testing.T) {
monitor := NewMonitor("test", redisCli)
profile := &MyProfiler{}
monitor.ListenEvent(profile)
for i := 0; i < size; i++ {
err := queue.SendDelayMsg(strconv.Itoa(i), 0)
if err != nil {
@@ -151,7 +151,7 @@ func TestMonitor_listener2(t *testing.T) {
}
}
for i := 0; i < 3; i++ {
queue.consume()
queue.beforeConsume()
}
if profile.RetryCount != size {
@@ -160,4 +160,4 @@ func TestMonitor_listener2(t *testing.T) {
if profile.FailCount != size {
t.Error("wrong consume count")
}
}
}

View File

@@ -35,11 +35,15 @@ func TestPublisher(t *testing.T) {
}
}
for i := 0; i < 10*size; i++ {
err := queue.consume()
ids, err := queue.beforeConsume()
if err != nil {
t.Errorf("consume error: %v", err)
return
}
for _, id := range ids {
queue.callback(id)
}
queue.afterConsume()
}
for k, v := range deliveryCount {
i, _ := strconv.ParseInt(k, 10, 64)
@@ -53,4 +57,4 @@ func TestPublisher(t *testing.T) {
}
}
}
}
}