feat: [wip] - implement storage

This commit is contained in:
sujit
2024-10-17 10:03:07 +05:45
parent 53b68572dd
commit ea266be846
13 changed files with 730 additions and 482 deletions

1
.gitignore vendored
View File

@@ -21,3 +21,4 @@
go.work
.idea
.DS_Store
*.svg

View File

@@ -37,13 +37,13 @@ type Consumer struct {
queue string
}
func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Consumer {
func NewConsumer(id string, queue string, handler Processor, opts ...Option) *Consumer {
options := SetupOptions(opts...)
return &Consumer{
id: id,
opts: options,
queue: queue,
handler: handler,
handler: handler.ProcessTask,
}
}
@@ -140,7 +140,7 @@ func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn
return
}
ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue})
if err := c.pool.AddTask(ctx, &task, 1); err != nil {
if err := c.pool.EnqueueTask(ctx, &task, 1); err != nil {
c.sendDenyMessage(ctx, task.ID, msg.Queue, err)
return
}
@@ -230,7 +230,7 @@ func (c *Consumer) Consume(ctx context.Context) error {
if err != nil {
return err
}
c.pool = NewPool(c.opts.numOfWorkers, c.opts.queueSize, c.opts.maxMemoryLoad, c.ProcessTask, c.OnResponse)
c.pool = NewPool(c.opts.numOfWorkers, c.opts.queueSize, c.opts.maxMemoryLoad, c.ProcessTask, c.OnResponse, c.opts.storage)
if err := c.subscribe(ctx, c.queue); err != nil {
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
}

1
ctx.go
View File

@@ -20,6 +20,7 @@ import (
type Task struct {
CreatedAt time.Time `json:"created_at"`
ProcessedAt time.Time `json:"processed_at"`
Expiry time.Time `json:"expiry"`
Error error `json:"error"`
ID string `json:"id"`
Topic string `json:"topic"`

View File

@@ -134,7 +134,7 @@ func NewDAG(name, key string, opts ...mq.Option) *DAG {
d.server = mq.NewBroker(opts...)
d.opts = opts
options := d.server.Options()
d.pool = mq.NewPool(options.NumOfWorkers(), options.QueueSize(), options.MaxMemoryLoad(), d.ProcessTask, callback)
d.pool = mq.NewPool(options.NumOfWorkers(), options.QueueSize(), options.MaxMemoryLoad(), d.ProcessTask, callback, options.Storage())
d.pool.Start(d.server.Options().NumOfWorkers())
go d.listenForTaskCleanup()
return d
@@ -349,7 +349,7 @@ func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result {
if ok {
ctxx = mq.SetHeaders(ctxx, headers.AsMap())
}
if err := tm.pool.AddTask(ctxx, task, 0); err != nil {
if err := tm.pool.EnqueueTask(ctxx, task, 0); err != nil {
return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: initialNode, Status: "FAILED", Error: err}
}
return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: initialNode, Status: "PENDING"}

View File

@@ -9,6 +9,6 @@ import (
)
func main() {
consumer1 := mq.NewConsumer("F", "queue1", tasks.Node6, mq.WithWorkerPool(100, 4, 50000))
consumer1 := mq.NewConsumer("F", "queue1", &tasks.Node6{}, mq.WithWorkerPool(100, 4, 50000))
consumer1.Consume(context.Background())
}

View File

@@ -9,12 +9,12 @@ import (
)
func main() {
pool := mq.NewPool(2, 5, 1000, tasks.SchedulerHandler, tasks.SchedulerCallback)
pool := mq.NewPool(2, 5, 1000, tasks.SchedulerHandler, tasks.SchedulerCallback, mq.NewMemoryTaskStorage(10*time.Minute))
time.Sleep(time.Millisecond)
pool.AddTask(context.Background(), &mq.Task{ID: "Low Priority Task"}, 1)
pool.AddTask(context.Background(), &mq.Task{ID: "Medium Priority Task"}, 5)
pool.AddTask(context.Background(), &mq.Task{ID: "High Priority Task"}, 10)
pool.EnqueueTask(context.Background(), &mq.Task{ID: "Low Priority Task"}, 1)
pool.EnqueueTask(context.Background(), &mq.Task{ID: "Medium Priority Task"}, 5)
pool.EnqueueTask(context.Background(), &mq.Task{ID: "High Priority Task"}, 10)
time.Sleep(5 * time.Second)
pool.PrintMetrics()

View File

@@ -11,14 +11,14 @@ import (
func main() {
handler := tasks.SchedulerHandler
callback := tasks.SchedulerCallback
pool := mq.NewPool(3, 5, 1000, handler, callback)
pool := mq.NewPool(3, 5, 1000, handler, callback, mq.NewMemoryTaskStorage(10*time.Minute))
ctx := context.Background()
pool.AddTask(context.Background(), &mq.Task{ID: "Task 1"}, 1)
pool.EnqueueTask(context.Background(), &mq.Task{ID: "Task 1"}, 1)
time.Sleep(1 * time.Second)
pool.AddTask(context.Background(), &mq.Task{ID: "Task 2"}, 5)
pool.Scheduler.AddTask(ctx, &mq.Task{ID: "Every Minute Task"})
pool.EnqueueTask(context.Background(), &mq.Task{ID: "Task 2"}, 5)
pool.Scheduler().AddTask(ctx, &mq.Task{ID: "Every Minute Task"})
time.Sleep(10 * time.Minute)
pool.Scheduler.RemoveTask("Every Minute Task")
pool.Scheduler().RemoveTask("Every Minute Task")
time.Sleep(5 * time.Minute)
pool.PrintMetrics()
pool.Stop()

4
go.mod
View File

@@ -1,8 +1,6 @@
module github.com/oarkflow/mq
go 1.22.3
toolchain go1.23.0
go 1.23
require (
github.com/oarkflow/date v0.0.4

View File

@@ -72,6 +72,7 @@ type Options struct {
callback []func(context.Context, Result) Result
maxRetries int
initialDelay time.Duration
storage TaskStorage
maxBackoff time.Duration
jitterPercent float64
queueSize int
@@ -90,6 +91,10 @@ func (o *Options) NumOfWorkers() int {
return o.numOfWorkers
}
func (o *Options) Storage() TaskStorage {
return o.storage
}
func (o *Options) QueueSize() int {
return o.queueSize
}
@@ -109,6 +114,7 @@ func defaultOptions() *Options {
queueSize: 100,
numOfWorkers: runtime.NumCPU(),
maxMemoryLoad: 5000000,
storage: NewMemoryTaskStorage(10 * time.Minute),
}
}

500
pool.go
View File

@@ -4,440 +4,16 @@ import (
"container/heap"
"context"
"fmt"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/oarkflow/mq/utils"
)
type QueueTask struct {
ctx context.Context
payload *Task
priority int
}
type PriorityQueue []*QueueTask
func (pq PriorityQueue) Len() int { return len(pq) }
func (pq PriorityQueue) Less(i, j int) bool {
return pq[i].priority > pq[j].priority
}
func (pq PriorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] }
func (pq *PriorityQueue) Push(x interface{}) {
item := x.(*QueueTask)
*pq = append(*pq, item)
}
func (pq *PriorityQueue) Pop() interface{} {
old := *pq
n := len(old)
item := old[n-1]
*pq = old[0 : n-1]
return item
}
type Callback func(ctx context.Context, result Result) error
type ScheduleOptions struct {
Handler Handler
Callback Callback
Overlap bool
Interval time.Duration
Recurring bool
}
type SchedulerOption func(*ScheduleOptions)
// Helper functions to create SchedulerOptions
func WithSchedulerHandler(handler Handler) SchedulerOption {
return func(opts *ScheduleOptions) {
opts.Handler = handler
}
}
func WithSchedulerCallback(callback Callback) SchedulerOption {
return func(opts *ScheduleOptions) {
opts.Callback = callback
}
}
func WithOverlap() SchedulerOption {
return func(opts *ScheduleOptions) {
opts.Overlap = true
}
}
func WithInterval(interval time.Duration) SchedulerOption {
return func(opts *ScheduleOptions) {
opts.Interval = interval
}
}
func WithRecurring() SchedulerOption {
return func(opts *ScheduleOptions) {
opts.Recurring = true
}
}
// defaultOptions returns the default scheduling options
func defaultSchedulerOptions() *ScheduleOptions {
return &ScheduleOptions{
Interval: time.Minute,
Recurring: true,
}
}
type SchedulerConfig struct {
Callback Callback
Overlap bool
}
type ScheduledTask struct {
ctx context.Context
handler Handler
payload *Task
config SchedulerConfig
schedule *Schedule
stop chan struct{}
executionHistory []ExecutionHistory
}
type Schedule struct {
Interval time.Duration
DayOfWeek []time.Weekday
DayOfMonth []int
TimeOfDay time.Time
Recurring bool
CronSpec string
}
func (s *Schedule) ToHumanReadable() string {
var sb strings.Builder
if s.CronSpec != "" {
cronDescription, err := parseCronSpec(s.CronSpec)
if err != nil {
sb.WriteString(fmt.Sprintf("Invalid CRON spec: %s\n", err.Error()))
} else {
sb.WriteString(fmt.Sprintf("CRON-based schedule: %s\n", cronDescription))
}
}
if s.Interval > 0 {
sb.WriteString(fmt.Sprintf("Recurring every %s\n", s.Interval))
}
if len(s.DayOfMonth) > 0 {
sb.WriteString("Occurs on the following days of the month: ")
for i, day := range s.DayOfMonth {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%d", day))
}
sb.WriteString("\n")
}
if len(s.DayOfWeek) > 0 {
sb.WriteString("Occurs on the following days of the week: ")
for i, day := range s.DayOfWeek {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(day.String())
}
sb.WriteString("\n")
}
if !s.TimeOfDay.IsZero() {
sb.WriteString(fmt.Sprintf("Time of day: %s\n", s.TimeOfDay.Format("15:04")))
}
if s.Recurring {
sb.WriteString("This schedule is recurring.\n")
} else {
sb.WriteString("This schedule is one-time.\n")
}
if sb.Len() == 0 {
sb.WriteString("No schedule defined.")
}
return sb.String()
}
type CronSchedule struct {
Minute string
Hour string
DayOfMonth string
Month string
DayOfWeek string
}
func (c CronSchedule) String() string {
return fmt.Sprintf("At %s minute(s) past %s, on %s, during %s, every %s", c.Minute, c.Hour, c.DayOfWeek, c.Month, c.DayOfMonth)
}
func parseCronSpec(cronSpec string) (CronSchedule, error) {
parts := strings.Fields(cronSpec)
if len(parts) != 5 {
return CronSchedule{}, fmt.Errorf("invalid CRON spec: expected 5 fields, got %d", len(parts))
}
minute, err := cronFieldToString(parts[0], "minute")
if err != nil {
return CronSchedule{}, err
}
hour, err := cronFieldToString(parts[1], "hour")
if err != nil {
return CronSchedule{}, err
}
dayOfMonth, err := cronFieldToString(parts[2], "day of the month")
if err != nil {
return CronSchedule{}, err
}
month, err := cronFieldToString(parts[3], "month")
if err != nil {
return CronSchedule{}, err
}
dayOfWeek, err := cronFieldToString(parts[4], "day of the week")
if err != nil {
return CronSchedule{}, err
}
return CronSchedule{
Minute: minute,
Hour: hour,
DayOfMonth: dayOfMonth,
Month: month,
DayOfWeek: dayOfWeek,
}, nil
}
func cronFieldToString(field string, fieldName string) (string, error) {
switch field {
case "*":
return fmt.Sprintf("every %s", fieldName), nil
default:
values, err := parseCronValue(field)
if err != nil {
return "", fmt.Errorf("invalid %s field: %s", fieldName, err.Error())
}
return fmt.Sprintf("%s %s", strings.Join(values, ", "), fieldName), nil
}
}
func parseCronValue(field string) ([]string, error) {
var values []string
ranges := strings.Split(field, ",")
for _, r := range ranges {
if strings.Contains(r, "-") {
bounds := strings.Split(r, "-")
if len(bounds) != 2 {
return nil, fmt.Errorf("invalid range: %s", r)
}
start, err := strconv.Atoi(bounds[0])
if err != nil {
return nil, err
}
end, err := strconv.Atoi(bounds[1])
if err != nil {
return nil, err
}
for i := start; i <= end; i++ {
values = append(values, strconv.Itoa(i))
}
} else {
values = append(values, r)
}
}
return values, nil
}
type Scheduler struct {
tasks []ScheduledTask
mu sync.Mutex
pool *Pool
}
func (s *Scheduler) Start() {
for _, task := range s.tasks {
go s.schedule(task)
}
}
func (s *Scheduler) schedule(task ScheduledTask) {
if task.schedule.Interval > 0 {
ticker := time.NewTicker(task.schedule.Interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.executeTask(task)
case <-task.stop:
return
}
}
} else if task.schedule.Recurring {
for {
now := time.Now()
nextRun := task.getNextRunTime(now)
if nextRun.After(now) {
time.Sleep(nextRun.Sub(now))
}
s.executeTask(task)
}
}
}
func NewScheduler(pool *Pool) *Scheduler {
return &Scheduler{pool: pool}
}
func (task ScheduledTask) getNextRunTime(now time.Time) time.Time {
if task.schedule.CronSpec != "" {
return task.getNextCronRunTime(now)
}
if len(task.schedule.DayOfMonth) > 0 {
for _, day := range task.schedule.DayOfMonth {
nextRun := time.Date(now.Year(), now.Month(), day, task.schedule.TimeOfDay.Hour(), task.schedule.TimeOfDay.Minute(), 0, 0, now.Location())
if nextRun.After(now) {
return nextRun
}
}
nextMonth := now.AddDate(0, 1, 0)
return time.Date(nextMonth.Year(), nextMonth.Month(), task.schedule.DayOfMonth[0], task.schedule.TimeOfDay.Hour(), task.schedule.TimeOfDay.Minute(), 0, 0, now.Location())
}
if len(task.schedule.DayOfWeek) > 0 {
for _, weekday := range task.schedule.DayOfWeek {
nextRun := nextWeekday(now, weekday).Truncate(time.Minute).Add(task.schedule.TimeOfDay.Sub(time.Time{}))
if nextRun.After(now) {
return nextRun
}
}
}
return now
}
func (task ScheduledTask) getNextCronRunTime(now time.Time) time.Time {
cronSpecs, err := parseCronSpec(task.schedule.CronSpec)
if err != nil {
fmt.Println(fmt.Sprintf("Invalid CRON spec format: %s", err.Error()))
return now
}
nextRun := now
nextRun = task.applyCronField(nextRun, cronSpecs.Minute, "minute")
nextRun = task.applyCronField(nextRun, cronSpecs.Hour, "hour")
nextRun = task.applyCronField(nextRun, cronSpecs.DayOfMonth, "day")
nextRun = task.applyCronField(nextRun, cronSpecs.Month, "month")
nextRun = task.applyCronField(nextRun, cronSpecs.DayOfWeek, "weekday")
return nextRun
}
func (task ScheduledTask) applyCronField(t time.Time, fieldSpec string, unit string) time.Time {
switch fieldSpec {
case "*":
return t
default:
value, _ := strconv.Atoi(fieldSpec)
switch unit {
case "minute":
if t.Minute() > value {
t = t.Add(time.Hour)
}
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), value, 0, 0, t.Location())
case "hour":
if t.Hour() > value {
t = t.AddDate(0, 0, 1)
}
t = time.Date(t.Year(), t.Month(), t.Day(), value, t.Minute(), 0, 0, t.Location())
case "day":
if t.Day() > value {
t = t.AddDate(0, 1, 0)
}
t = time.Date(t.Year(), t.Month(), value, t.Hour(), t.Minute(), 0, 0, t.Location())
case "month":
if int(t.Month()) > value {
t = t.AddDate(1, 0, 0)
}
t = time.Date(t.Year(), time.Month(value), t.Day(), t.Hour(), t.Minute(), 0, 0, t.Location())
case "weekday":
weekday := time.Weekday(value)
for t.Weekday() != weekday {
t = t.AddDate(0, 0, 1)
}
}
return t
}
}
func nextWeekday(t time.Time, weekday time.Weekday) time.Time {
daysUntil := (int(weekday) - int(t.Weekday()) + 7) % 7
if daysUntil == 0 {
daysUntil = 7
}
return t.AddDate(0, 0, daysUntil)
}
func (s *Scheduler) executeTask(task ScheduledTask) {
if task.config.Overlap || len(task.schedule.DayOfWeek) == 0 {
go func() {
result := task.handler(task.ctx, task.payload)
task.executionHistory = append(task.executionHistory, ExecutionHistory{Timestamp: time.Now(), Result: result})
if task.config.Callback != nil {
_ = task.config.Callback(task.ctx, result)
}
fmt.Printf("Executed scheduled task: %s\n", task.payload.ID)
}()
}
}
func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...SchedulerOption) {
s.mu.Lock()
defer s.mu.Unlock()
// Create a default options instance
options := defaultSchedulerOptions()
for _, opt := range opts {
opt(options)
}
if options.Handler == nil {
options.Handler = s.pool.handler
}
if options.Callback == nil {
options.Callback = s.pool.callback
}
stop := make(chan struct{})
// Create a new ScheduledTask using the provided options
s.tasks = append(s.tasks, ScheduledTask{
ctx: ctx,
handler: options.Handler,
payload: payload,
stop: stop,
config: SchedulerConfig{
Callback: options.Callback,
Overlap: options.Overlap,
},
schedule: &Schedule{
Interval: options.Interval,
Recurring: options.Recurring,
},
})
// Start scheduling the task
go s.schedule(s.tasks[len(s.tasks)-1])
}
func (s *Scheduler) RemoveTask(payloadID string) {
s.mu.Lock()
defer s.mu.Unlock()
for i, task := range s.tasks {
if task.payload.ID == payloadID {
close(task.stop)
s.tasks = append(s.tasks[:i], s.tasks[i+1:]...)
break
}
}
}
type Pool struct {
taskStorage TaskStorage
taskQueue PriorityQueue
taskQueueLock sync.Mutex
stop chan struct{}
@@ -452,7 +28,7 @@ type Pool struct {
totalTasks int
numOfWorkers int32
paused bool
Scheduler *Scheduler
scheduler *Scheduler
totalScheduledTasks int
}
@@ -460,7 +36,8 @@ func NewPool(
numOfWorkers, taskQueueSize int,
maxMemoryLoad int64,
handler Handler,
callback Callback) *Pool {
callback Callback,
storage TaskStorage) *Pool {
pool := &Pool{
taskQueue: make(PriorityQueue, 0, taskQueueSize),
stop: make(chan struct{}),
@@ -468,16 +45,25 @@ func NewPool(
maxMemoryLoad: maxMemoryLoad,
handler: handler,
callback: callback,
taskStorage: storage,
workerAdjust: make(chan int),
}
pool.Scheduler = NewScheduler(pool)
pool.scheduler = NewScheduler(pool)
heap.Init(&pool.taskQueue)
pool.Scheduler.Start()
pool.scheduler.Start()
pool.Start(numOfWorkers)
return pool
}
func (wp *Pool) Start(numWorkers int) {
storedTasks, err := wp.taskStorage.GetAllTasks()
if err == nil {
wp.taskQueueLock.Lock()
for _, task := range storedTasks {
heap.Push(&wp.taskQueue, task)
}
wp.taskQueueLock.Unlock()
}
for i := 0; i < numWorkers; i++ {
wp.wg.Add(1)
go wp.worker()
@@ -492,9 +78,20 @@ func (wp *Pool) worker() {
select {
case <-wp.taskNotify:
wp.taskQueueLock.Lock()
var task *QueueTask
if len(wp.taskQueue) > 0 && !wp.paused {
task := heap.Pop(&wp.taskQueue).(*QueueTask)
task = heap.Pop(&wp.taskQueue).(*QueueTask)
}
wp.taskQueueLock.Unlock()
if task == nil && !wp.paused {
var err error
task, err = wp.taskStorage.FetchNextTask()
if err != nil {
continue
}
}
if task != nil {
taskSize := int64(utils.SizeOf(task.payload))
wp.totalMemoryUsed += taskSize
wp.totalTasks++
@@ -508,10 +105,11 @@ func (wp *Pool) worker() {
if err := wp.callback(task.ctx, result); err != nil {
wp.errorCount++
}
}
if err := wp.taskStorage.DeleteTask(task.payload.ID); err != nil {
}
wp.totalMemoryUsed -= taskSize
} else {
wp.taskQueueLock.Unlock()
}
case <-wp.stop:
return
@@ -549,13 +147,16 @@ func (wp *Pool) adjustWorkers(newWorkerCount int) {
atomic.StoreInt32(&wp.numOfWorkers, int32(newWorkerCount))
}
func (wp *Pool) AddTask(ctx context.Context, payload *Task, priority int) error {
func (wp *Pool) EnqueueTask(ctx context.Context, payload *Task, priority int) error {
if payload.ID == "" {
payload.ID = NewID()
}
task := &QueueTask{ctx: ctx, payload: payload, priority: priority}
if err := wp.taskStorage.SaveTask(task); err != nil {
return err
}
wp.taskQueueLock.Lock()
defer wp.taskQueueLock.Unlock()
task := &QueueTask{ctx: ctx, payload: payload, priority: priority}
taskSize := int64(utils.SizeOf(payload))
if wp.totalMemoryUsed+taskSize > wp.maxMemoryLoad && wp.maxMemoryLoad > 0 {
return fmt.Errorf("max memory load reached, cannot add task of size %d", taskSize)
@@ -590,34 +191,9 @@ func (wp *Pool) AdjustWorkerCount(newWorkerCount int) {
func (wp *Pool) PrintMetrics() {
fmt.Printf("Total Tasks: %d, Completed Tasks: %d, Error Count: %d, Total Memory Used: %d bytes, Total Scheduled Tasks: %d\n",
wp.totalTasks, wp.completedTasks, wp.errorCount, wp.totalMemoryUsed, len(wp.Scheduler.tasks))
wp.totalTasks, wp.completedTasks, wp.errorCount, wp.totalMemoryUsed, len(wp.scheduler.tasks))
}
type ExecutionHistory struct {
Timestamp time.Time
Result Result
}
func (s *Scheduler) PrintAllTasks() {
s.mu.Lock()
defer s.mu.Unlock()
fmt.Println("Scheduled Tasks:")
for _, task := range s.tasks {
fmt.Printf("Task ID: %s, Next Execution: %s\n", task.payload.ID, task.getNextRunTime(time.Now()))
}
}
func (s *Scheduler) PrintExecutionHistory(taskID string) {
s.mu.Lock()
defer s.mu.Unlock()
for _, task := range s.tasks {
if task.payload.ID == taskID {
fmt.Printf("Execution History for Task ID: %s\n", taskID)
for _, history := range task.executionHistory {
fmt.Printf("Timestamp: %s, Result: %v\n", history.Timestamp, history.Result.Error)
}
return
}
}
fmt.Printf("No task found with ID: %s\n", taskID)
func (wp *Pool) Scheduler() *Scheduler {
return wp.scheduler
}

View File

@@ -1,6 +1,8 @@
package mq
import (
"context"
"github.com/oarkflow/mq/storage"
"github.com/oarkflow/mq/storage/memory"
)
@@ -29,3 +31,32 @@ func (b *Broker) NewQueue(qName string) *Queue {
go b.dispatchWorker(q)
return q
}
type QueueTask struct {
ctx context.Context
payload *Task
priority int
}
type PriorityQueue []*QueueTask
func (pq PriorityQueue) Len() int { return len(pq) }
func (pq PriorityQueue) Less(i, j int) bool {
return pq[i].priority > pq[j].priority
}
func (pq PriorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] }
func (pq *PriorityQueue) Push(x interface{}) {
item := x.(*QueueTask)
*pq = append(*pq, item)
}
func (pq *PriorityQueue) Pop() interface{} {
old := *pq
n := len(old)
item := old[n-1]
*pq = old[0 : n-1]
return item
}

432
scheduler.go Normal file
View File

@@ -0,0 +1,432 @@
package mq
import (
"context"
"fmt"
"strconv"
"strings"
"sync"
"time"
)
type ScheduleOptions struct {
Handler Handler
Callback Callback
Overlap bool
Interval time.Duration
Recurring bool
}
type SchedulerOption func(*ScheduleOptions)
// Helper functions to create SchedulerOptions
func WithSchedulerHandler(handler Handler) SchedulerOption {
return func(opts *ScheduleOptions) {
opts.Handler = handler
}
}
func WithSchedulerCallback(callback Callback) SchedulerOption {
return func(opts *ScheduleOptions) {
opts.Callback = callback
}
}
func WithOverlap() SchedulerOption {
return func(opts *ScheduleOptions) {
opts.Overlap = true
}
}
func WithInterval(interval time.Duration) SchedulerOption {
return func(opts *ScheduleOptions) {
opts.Interval = interval
}
}
func WithRecurring() SchedulerOption {
return func(opts *ScheduleOptions) {
opts.Recurring = true
}
}
// defaultOptions returns the default scheduling options
func defaultSchedulerOptions() *ScheduleOptions {
return &ScheduleOptions{
Interval: time.Minute,
Recurring: true,
}
}
type SchedulerConfig struct {
Callback Callback
Overlap bool
}
type ScheduledTask struct {
ctx context.Context
handler Handler
payload *Task
config SchedulerConfig
schedule *Schedule
stop chan struct{}
executionHistory []ExecutionHistory
}
type Schedule struct {
Interval time.Duration
DayOfWeek []time.Weekday
DayOfMonth []int
TimeOfDay time.Time
Recurring bool
CronSpec string
}
func (s *Schedule) ToHumanReadable() string {
var sb strings.Builder
if s.CronSpec != "" {
cronDescription, err := parseCronSpec(s.CronSpec)
if err != nil {
sb.WriteString(fmt.Sprintf("Invalid CRON spec: %s\n", err.Error()))
} else {
sb.WriteString(fmt.Sprintf("CRON-based schedule: %s\n", cronDescription))
}
}
if s.Interval > 0 {
sb.WriteString(fmt.Sprintf("Recurring every %s\n", s.Interval))
}
if len(s.DayOfMonth) > 0 {
sb.WriteString("Occurs on the following days of the month: ")
for i, day := range s.DayOfMonth {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%d", day))
}
sb.WriteString("\n")
}
if len(s.DayOfWeek) > 0 {
sb.WriteString("Occurs on the following days of the week: ")
for i, day := range s.DayOfWeek {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(day.String())
}
sb.WriteString("\n")
}
if !s.TimeOfDay.IsZero() {
sb.WriteString(fmt.Sprintf("Time of day: %s\n", s.TimeOfDay.Format("15:04")))
}
if s.Recurring {
sb.WriteString("This schedule is recurring.\n")
} else {
sb.WriteString("This schedule is one-time.\n")
}
if sb.Len() == 0 {
sb.WriteString("No schedule defined.")
}
return sb.String()
}
type CronSchedule struct {
Minute string
Hour string
DayOfMonth string
Month string
DayOfWeek string
}
func (c CronSchedule) String() string {
return fmt.Sprintf("At %s minute(s) past %s, on %s, during %s, every %s", c.Minute, c.Hour, c.DayOfWeek, c.Month, c.DayOfMonth)
}
func parseCronSpec(cronSpec string) (CronSchedule, error) {
parts := strings.Fields(cronSpec)
if len(parts) != 5 {
return CronSchedule{}, fmt.Errorf("invalid CRON spec: expected 5 fields, got %d", len(parts))
}
minute, err := cronFieldToString(parts[0], "minute")
if err != nil {
return CronSchedule{}, err
}
hour, err := cronFieldToString(parts[1], "hour")
if err != nil {
return CronSchedule{}, err
}
dayOfMonth, err := cronFieldToString(parts[2], "day of the month")
if err != nil {
return CronSchedule{}, err
}
month, err := cronFieldToString(parts[3], "month")
if err != nil {
return CronSchedule{}, err
}
dayOfWeek, err := cronFieldToString(parts[4], "day of the week")
if err != nil {
return CronSchedule{}, err
}
return CronSchedule{
Minute: minute,
Hour: hour,
DayOfMonth: dayOfMonth,
Month: month,
DayOfWeek: dayOfWeek,
}, nil
}
func cronFieldToString(field string, fieldName string) (string, error) {
switch field {
case "*":
return fmt.Sprintf("every %s", fieldName), nil
default:
values, err := parseCronValue(field)
if err != nil {
return "", fmt.Errorf("invalid %s field: %s", fieldName, err.Error())
}
return fmt.Sprintf("%s %s", strings.Join(values, ", "), fieldName), nil
}
}
func parseCronValue(field string) ([]string, error) {
var values []string
ranges := strings.Split(field, ",")
for _, r := range ranges {
if strings.Contains(r, "-") {
bounds := strings.Split(r, "-")
if len(bounds) != 2 {
return nil, fmt.Errorf("invalid range: %s", r)
}
start, err := strconv.Atoi(bounds[0])
if err != nil {
return nil, err
}
end, err := strconv.Atoi(bounds[1])
if err != nil {
return nil, err
}
for i := start; i <= end; i++ {
values = append(values, strconv.Itoa(i))
}
} else {
values = append(values, r)
}
}
return values, nil
}
type Scheduler struct {
tasks []ScheduledTask
mu sync.Mutex
pool *Pool
}
func (s *Scheduler) Start() {
for _, task := range s.tasks {
go s.schedule(task)
}
}
func (s *Scheduler) schedule(task ScheduledTask) {
if task.schedule.Interval > 0 {
ticker := time.NewTicker(task.schedule.Interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
s.executeTask(task)
case <-task.stop:
return
}
}
} else if task.schedule.Recurring {
for {
now := time.Now()
nextRun := task.getNextRunTime(now)
if nextRun.After(now) {
time.Sleep(nextRun.Sub(now))
}
s.executeTask(task)
}
}
}
func NewScheduler(pool *Pool) *Scheduler {
return &Scheduler{pool: pool}
}
func (task ScheduledTask) getNextRunTime(now time.Time) time.Time {
if task.schedule.CronSpec != "" {
return task.getNextCronRunTime(now)
}
if len(task.schedule.DayOfMonth) > 0 {
for _, day := range task.schedule.DayOfMonth {
nextRun := time.Date(now.Year(), now.Month(), day, task.schedule.TimeOfDay.Hour(), task.schedule.TimeOfDay.Minute(), 0, 0, now.Location())
if nextRun.After(now) {
return nextRun
}
}
nextMonth := now.AddDate(0, 1, 0)
return time.Date(nextMonth.Year(), nextMonth.Month(), task.schedule.DayOfMonth[0], task.schedule.TimeOfDay.Hour(), task.schedule.TimeOfDay.Minute(), 0, 0, now.Location())
}
if len(task.schedule.DayOfWeek) > 0 {
for _, weekday := range task.schedule.DayOfWeek {
nextRun := nextWeekday(now, weekday).Truncate(time.Minute).Add(task.schedule.TimeOfDay.Sub(time.Time{}))
if nextRun.After(now) {
return nextRun
}
}
}
return now
}
func (task ScheduledTask) getNextCronRunTime(now time.Time) time.Time {
cronSpecs, err := parseCronSpec(task.schedule.CronSpec)
if err != nil {
fmt.Println(fmt.Sprintf("Invalid CRON spec format: %s", err.Error()))
return now
}
nextRun := now
nextRun = task.applyCronField(nextRun, cronSpecs.Minute, "minute")
nextRun = task.applyCronField(nextRun, cronSpecs.Hour, "hour")
nextRun = task.applyCronField(nextRun, cronSpecs.DayOfMonth, "day")
nextRun = task.applyCronField(nextRun, cronSpecs.Month, "month")
nextRun = task.applyCronField(nextRun, cronSpecs.DayOfWeek, "weekday")
return nextRun
}
func (task ScheduledTask) applyCronField(t time.Time, fieldSpec string, unit string) time.Time {
switch fieldSpec {
case "*":
return t
default:
value, _ := strconv.Atoi(fieldSpec)
switch unit {
case "minute":
if t.Minute() > value {
t = t.Add(time.Hour)
}
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), value, 0, 0, t.Location())
case "hour":
if t.Hour() > value {
t = t.AddDate(0, 0, 1)
}
t = time.Date(t.Year(), t.Month(), t.Day(), value, t.Minute(), 0, 0, t.Location())
case "day":
if t.Day() > value {
t = t.AddDate(0, 1, 0)
}
t = time.Date(t.Year(), t.Month(), value, t.Hour(), t.Minute(), 0, 0, t.Location())
case "month":
if int(t.Month()) > value {
t = t.AddDate(1, 0, 0)
}
t = time.Date(t.Year(), time.Month(value), t.Day(), t.Hour(), t.Minute(), 0, 0, t.Location())
case "weekday":
weekday := time.Weekday(value)
for t.Weekday() != weekday {
t = t.AddDate(0, 0, 1)
}
}
return t
}
}
func nextWeekday(t time.Time, weekday time.Weekday) time.Time {
daysUntil := (int(weekday) - int(t.Weekday()) + 7) % 7
if daysUntil == 0 {
daysUntil = 7
}
return t.AddDate(0, 0, daysUntil)
}
func (s *Scheduler) executeTask(task ScheduledTask) {
if task.config.Overlap || len(task.schedule.DayOfWeek) == 0 {
go func() {
result := task.handler(task.ctx, task.payload)
task.executionHistory = append(task.executionHistory, ExecutionHistory{Timestamp: time.Now(), Result: result})
if task.config.Callback != nil {
_ = task.config.Callback(task.ctx, result)
}
fmt.Printf("Executed scheduled task: %s\n", task.payload.ID)
}()
}
}
func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...SchedulerOption) {
s.mu.Lock()
defer s.mu.Unlock()
// Create a default options instance
options := defaultSchedulerOptions()
for _, opt := range opts {
opt(options)
}
if options.Handler == nil {
options.Handler = s.pool.handler
}
if options.Callback == nil {
options.Callback = s.pool.callback
}
stop := make(chan struct{})
// Create a new ScheduledTask using the provided options
s.tasks = append(s.tasks, ScheduledTask{
ctx: ctx,
handler: options.Handler,
payload: payload,
stop: stop,
config: SchedulerConfig{
Callback: options.Callback,
Overlap: options.Overlap,
},
schedule: &Schedule{
Interval: options.Interval,
Recurring: options.Recurring,
},
})
// Start scheduling the task
go s.schedule(s.tasks[len(s.tasks)-1])
}
func (s *Scheduler) RemoveTask(payloadID string) {
s.mu.Lock()
defer s.mu.Unlock()
for i, task := range s.tasks {
if task.payload.ID == payloadID {
close(task.stop)
s.tasks = append(s.tasks[:i], s.tasks[i+1:]...)
break
}
}
}
type ExecutionHistory struct {
Timestamp time.Time
Result Result
}
func (s *Scheduler) PrintAllTasks() {
s.mu.Lock()
defer s.mu.Unlock()
fmt.Println("Scheduled Tasks:")
for _, task := range s.tasks {
fmt.Printf("Task ID: %s, Next Execution: %s\n", task.payload.ID, task.getNextRunTime(time.Now()))
}
}
func (s *Scheduler) PrintExecutionHistory(taskID string) {
s.mu.Lock()
defer s.mu.Unlock()
for _, task := range s.tasks {
if task.payload.ID == taskID {
fmt.Printf("Execution History for Task ID: %s\n", taskID)
for _, history := range task.executionHistory {
fmt.Printf("Timestamp: %s, Result: %v\n", history.Timestamp, history.Result.Error)
}
return
}
}
fmt.Printf("No task found with ID: %s\n", taskID)
}

203
storage.go Normal file
View File

@@ -0,0 +1,203 @@
package mq
import (
"container/heap"
"fmt"
"sync"
"time"
)
type TaskStorage interface {
SaveTask(task *QueueTask) error
GetTask(taskID string) (*QueueTask, error)
DeleteTask(taskID string) error
GetAllTasks() ([]*QueueTask, error)
FetchNextTask() (*QueueTask, error)
CleanupExpiredTasks() error
}
type MemoryTaskStorage struct {
tasks PriorityQueue
taskLock sync.Mutex
expiryTime time.Duration
}
func NewMemoryTaskStorage(expiryTime time.Duration) *MemoryTaskStorage {
return &MemoryTaskStorage{
tasks: make(PriorityQueue, 0),
expiryTime: expiryTime,
}
}
func (m *MemoryTaskStorage) SaveTask(task *QueueTask) error {
m.taskLock.Lock()
defer m.taskLock.Unlock()
heap.Push(&m.tasks, task)
return nil
}
func (m *MemoryTaskStorage) GetTask(taskID string) (*QueueTask, error) {
m.taskLock.Lock()
defer m.taskLock.Unlock()
for _, task := range m.tasks {
if task.payload.ID == taskID {
return task, nil
}
}
return nil, fmt.Errorf("task not found")
}
func (m *MemoryTaskStorage) DeleteTask(taskID string) error {
m.taskLock.Lock()
defer m.taskLock.Unlock()
for i, task := range m.tasks {
if task.payload.ID == taskID {
heap.Remove(&m.tasks, i)
return nil
}
}
return fmt.Errorf("task not found")
}
func (m *MemoryTaskStorage) GetAllTasks() ([]*QueueTask, error) {
m.taskLock.Lock()
defer m.taskLock.Unlock()
tasks := make([]*QueueTask, len(m.tasks))
for i, task := range m.tasks {
tasks[i] = task
}
return tasks, nil
}
func (m *MemoryTaskStorage) FetchNextTask() (*QueueTask, error) {
m.taskLock.Lock()
defer m.taskLock.Unlock()
if len(m.tasks) == 0 {
return nil, fmt.Errorf("no tasks available")
}
task := heap.Pop(&m.tasks).(*QueueTask)
if task.payload.CreatedAt.Add(m.expiryTime).Before(time.Now()) {
m.DeleteTask(task.payload.ID)
return m.FetchNextTask()
}
return task, nil
}
func (m *MemoryTaskStorage) CleanupExpiredTasks() error {
m.taskLock.Lock()
defer m.taskLock.Unlock()
for i := 0; i < len(m.tasks); i++ {
task := m.tasks[i]
if task.payload.CreatedAt.Add(m.expiryTime).Before(time.Now()) {
heap.Remove(&m.tasks, i)
i-- // Adjust index after removal
}
}
return nil
}
/*
type PostgresTaskStorage struct {
db *sql.DB
}
func NewPostgresTaskStorage(db *sql.DB) *PostgresTaskStorage {
return &PostgresTaskStorage{db: db}
}
func (p *PostgresTaskStorage) SaveTask(task *QueueTask) error {
query := `
INSERT INTO tasks (id, payload, priority, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (id) DO NOTHING`
payloadBytes, err := utils.Serialize(task.payload) // Serialize converts the task to bytes
if err != nil {
return err
}
_, err = p.db.Exec(query, task.payload.ID, payloadBytes, task.priority, task.payload.CreatedAt, time.Now())
return err
}
func (p *PostgresTaskStorage) GetTask(taskID string) (*QueueTask, error) {
query := `SELECT id, payload, priority, created_at FROM tasks WHERE id = $1`
var task QueueTask
var payloadBytes []byte
err := p.db.QueryRow(query, taskID).Scan(&task.payload.ID, &payloadBytes, &task.priority, &task.payload.CreatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("task not found")
}
return nil, err
}
task.payload, err = utils.Deserialize(payloadBytes) // Deserialize converts bytes to Task object
if err != nil {
return nil, err
}
return &task, nil
}
func (p *PostgresTaskStorage) DeleteTask(taskID string) error {
query := `DELETE FROM tasks WHERE id = $1`
_, err := p.db.Exec(query, taskID)
return err
}
func (p *PostgresTaskStorage) GetAllTasks() ([]*QueueTask, error) {
query := `SELECT id, payload, priority, created_at FROM tasks`
rows, err := p.db.Query(query)
if err != nil {
return nil, err
}
defer rows.Close()
var tasks []*QueueTask
for rows.Next() {
var task QueueTask
var payloadBytes []byte
err := rows.Scan(&task.payload.ID, &payloadBytes, &task.priority, &task.payload.CreatedAt)
if err != nil {
return nil, err
}
task.payload, err = utils.Deserialize(payloadBytes)
if err != nil {
return nil, err
}
tasks = append(tasks, &task)
}
return tasks, nil
}
func (p *PostgresTaskStorage) FetchNextTask() (*QueueTask, error) {
query := `
SELECT id, payload, priority, created_at FROM tasks
ORDER BY priority DESC, created_at ASC
LIMIT 1`
var task QueueTask
var payloadBytes []byte
err := p.db.QueryRow(query).Scan(&task.payload.ID, &payloadBytes, &task.priority, &task.payload.CreatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, fmt.Errorf("no tasks available")
}
return nil, err
}
task.payload, err = utils.Deserialize(payloadBytes)
if err != nil {
return nil, err
}
return &task, nil
}
func (p *PostgresTaskStorage) CleanupExpiredTasks() error {
query := `DELETE FROM tasks WHERE created_at < $1`
_, err := p.db.Exec(query, time.Now().Add(-time.Hour*24)) // Assuming tasks older than 24 hours are expired
return err
}
*/