diff --git a/.gitignore b/.gitignore index d09d700..d918e17 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ go.work .idea .DS_Store +*.svg \ No newline at end of file diff --git a/consumer.go b/consumer.go index 05a28e1..3e75863 100644 --- a/consumer.go +++ b/consumer.go @@ -37,13 +37,13 @@ type Consumer struct { queue string } -func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Consumer { +func NewConsumer(id string, queue string, handler Processor, opts ...Option) *Consumer { options := SetupOptions(opts...) return &Consumer{ id: id, opts: options, queue: queue, - handler: handler, + handler: handler.ProcessTask, } } @@ -140,7 +140,7 @@ func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn return } ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue}) - if err := c.pool.AddTask(ctx, &task, 1); err != nil { + if err := c.pool.EnqueueTask(ctx, &task, 1); err != nil { c.sendDenyMessage(ctx, task.ID, msg.Queue, err) return } @@ -230,7 +230,7 @@ func (c *Consumer) Consume(ctx context.Context) error { if err != nil { return err } - c.pool = NewPool(c.opts.numOfWorkers, c.opts.queueSize, c.opts.maxMemoryLoad, c.ProcessTask, c.OnResponse) + c.pool = NewPool(c.opts.numOfWorkers, c.opts.queueSize, c.opts.maxMemoryLoad, c.ProcessTask, c.OnResponse, c.opts.storage) if err := c.subscribe(ctx, c.queue); err != nil { return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err) } diff --git a/ctx.go b/ctx.go index 122bb09..2d16ed7 100644 --- a/ctx.go +++ b/ctx.go @@ -20,6 +20,7 @@ import ( type Task struct { CreatedAt time.Time `json:"created_at"` ProcessedAt time.Time `json:"processed_at"` + Expiry time.Time `json:"expiry"` Error error `json:"error"` ID string `json:"id"` Topic string `json:"topic"` diff --git a/dag/dag.go b/dag/dag.go index f0976da..da3de2f 100644 --- a/dag/dag.go +++ b/dag/dag.go @@ -134,7 +134,7 @@ func NewDAG(name, key string, opts ...mq.Option) *DAG { d.server = mq.NewBroker(opts...) 
d.opts = opts options := d.server.Options() - d.pool = mq.NewPool(options.NumOfWorkers(), options.QueueSize(), options.MaxMemoryLoad(), d.ProcessTask, callback) + d.pool = mq.NewPool(options.NumOfWorkers(), options.QueueSize(), options.MaxMemoryLoad(), d.ProcessTask, callback, options.Storage()) d.pool.Start(d.server.Options().NumOfWorkers()) go d.listenForTaskCleanup() return d @@ -349,7 +349,7 @@ func (tm *DAG) Process(ctx context.Context, payload []byte) mq.Result { if ok { ctxx = mq.SetHeaders(ctxx, headers.AsMap()) } - if err := tm.pool.AddTask(ctxx, task, 0); err != nil { + if err := tm.pool.EnqueueTask(ctxx, task, 0); err != nil { return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: initialNode, Status: "FAILED", Error: err} } return mq.Result{CreatedAt: task.CreatedAt, TaskID: task.ID, Topic: initialNode, Status: "PENDING"} diff --git a/examples/consumer.go b/examples/consumer.go index 3b84207..84b7a6b 100644 --- a/examples/consumer.go +++ b/examples/consumer.go @@ -9,6 +9,6 @@ import ( ) func main() { - consumer1 := mq.NewConsumer("F", "queue1", tasks.Node6, mq.WithWorkerPool(100, 4, 50000)) + consumer1 := mq.NewConsumer("F", "queue1", &tasks.Node6{}, mq.WithWorkerPool(100, 4, 50000)) consumer1.Consume(context.Background()) } diff --git a/examples/priority.go b/examples/priority.go index c646f05..c97f862 100644 --- a/examples/priority.go +++ b/examples/priority.go @@ -9,12 +9,12 @@ import ( ) func main() { - pool := mq.NewPool(2, 5, 1000, tasks.SchedulerHandler, tasks.SchedulerCallback) + pool := mq.NewPool(2, 5, 1000, tasks.SchedulerHandler, tasks.SchedulerCallback, mq.NewMemoryTaskStorage(10*time.Minute)) time.Sleep(time.Millisecond) - pool.AddTask(context.Background(), &mq.Task{ID: "Low Priority Task"}, 1) - pool.AddTask(context.Background(), &mq.Task{ID: "Medium Priority Task"}, 5) - pool.AddTask(context.Background(), &mq.Task{ID: "High Priority Task"}, 10) + pool.EnqueueTask(context.Background(), &mq.Task{ID: "Low Priority Task"}, 1) + pool.EnqueueTask(context.Background(), &mq.Task{ID: "Medium Priority Task"}, 5) + pool.EnqueueTask(context.Background(), &mq.Task{ID: "High Priority Task"}, 10) time.Sleep(5 * time.Second) pool.PrintMetrics() diff --git a/examples/scheduler.go b/examples/scheduler.go index cb0e1d0..f854286 100644 --- a/examples/scheduler.go +++ b/examples/scheduler.go @@ -11,14 +11,14 @@ import ( func main() { handler := tasks.SchedulerHandler callback := tasks.SchedulerCallback - pool := mq.NewPool(3, 5, 1000, handler, callback) + pool := mq.NewPool(3, 5, 1000, handler, callback, mq.NewMemoryTaskStorage(10*time.Minute)) ctx := context.Background() - pool.AddTask(context.Background(), &mq.Task{ID: "Task 1"}, 1) + pool.EnqueueTask(context.Background(), &mq.Task{ID: "Task 1"}, 1) time.Sleep(1 * time.Second) - pool.AddTask(context.Background(), &mq.Task{ID: "Task 2"}, 5) - pool.Scheduler.AddTask(ctx, &mq.Task{ID: "Every Minute Task"}) + pool.EnqueueTask(context.Background(), &mq.Task{ID: "Task 2"}, 5) + pool.Scheduler().AddTask(ctx, &mq.Task{ID: "Every Minute Task"}) time.Sleep(10 * time.Minute) - pool.Scheduler.RemoveTask("Every Minute Task") + pool.Scheduler().RemoveTask("Every Minute Task") time.Sleep(5 * time.Minute) pool.PrintMetrics() pool.Stop() diff --git a/go.mod b/go.mod index ee1b125..7b82534 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module github.com/oarkflow/mq -go 1.22.3 - -toolchain go1.23.0 +go 1.23 require ( github.com/oarkflow/date v0.0.4 diff --git a/options.go b/options.go index 7001d64..2f47c88 100644 --- a/options.go +++ 
b/options.go @@ -72,6 +72,7 @@ type Options struct { callback []func(context.Context, Result) Result maxRetries int initialDelay time.Duration + storage TaskStorage maxBackoff time.Duration jitterPercent float64 queueSize int @@ -90,6 +91,10 @@ func (o *Options) NumOfWorkers() int { return o.numOfWorkers } +func (o *Options) Storage() TaskStorage { + return o.storage +} + func (o *Options) QueueSize() int { return o.queueSize } @@ -109,6 +114,7 @@ func defaultOptions() *Options { queueSize: 100, numOfWorkers: runtime.NumCPU(), maxMemoryLoad: 5000000, + storage: NewMemoryTaskStorage(10 * time.Minute), } } diff --git a/pool.go b/pool.go index 0f8152e..f0e6c55 100644 --- a/pool.go +++ b/pool.go @@ -4,440 +4,16 @@ import ( "container/heap" "context" "fmt" - "strconv" - "strings" "sync" "sync/atomic" - "time" "github.com/oarkflow/mq/utils" ) -type QueueTask struct { - ctx context.Context - payload *Task - priority int -} - -type PriorityQueue []*QueueTask - -func (pq PriorityQueue) Len() int { return len(pq) } - -func (pq PriorityQueue) Less(i, j int) bool { - return pq[i].priority > pq[j].priority -} - -func (pq PriorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] } - -func (pq *PriorityQueue) Push(x interface{}) { - item := x.(*QueueTask) - *pq = append(*pq, item) -} - -func (pq *PriorityQueue) Pop() interface{} { - old := *pq - n := len(old) - item := old[n-1] - *pq = old[0 : n-1] - return item -} - type Callback func(ctx context.Context, result Result) error -type ScheduleOptions struct { - Handler Handler - Callback Callback - Overlap bool - Interval time.Duration - Recurring bool -} - -type SchedulerOption func(*ScheduleOptions) - -// Helper functions to create SchedulerOptions -func WithSchedulerHandler(handler Handler) SchedulerOption { - return func(opts *ScheduleOptions) { - opts.Handler = handler - } -} - -func WithSchedulerCallback(callback Callback) SchedulerOption { - return func(opts *ScheduleOptions) { - opts.Callback = callback - } -} - -func WithOverlap() SchedulerOption { - return func(opts *ScheduleOptions) { - opts.Overlap = true - } -} - -func WithInterval(interval time.Duration) SchedulerOption { - return func(opts *ScheduleOptions) { - opts.Interval = interval - } -} - -func WithRecurring() SchedulerOption { - return func(opts *ScheduleOptions) { - opts.Recurring = true - } -} - -// defaultOptions returns the default scheduling options -func defaultSchedulerOptions() *ScheduleOptions { - return &ScheduleOptions{ - Interval: time.Minute, - Recurring: true, - } -} - -type SchedulerConfig struct { - Callback Callback - Overlap bool -} - -type ScheduledTask struct { - ctx context.Context - handler Handler - payload *Task - config SchedulerConfig - schedule *Schedule - stop chan struct{} - executionHistory []ExecutionHistory -} - -type Schedule struct { - Interval time.Duration - DayOfWeek []time.Weekday - DayOfMonth []int - TimeOfDay time.Time - Recurring bool - CronSpec string -} - -func (s *Schedule) ToHumanReadable() string { - var sb strings.Builder - if s.CronSpec != "" { - cronDescription, err := parseCronSpec(s.CronSpec) - if err != nil { - sb.WriteString(fmt.Sprintf("Invalid CRON spec: %s\n", err.Error())) - } else { - sb.WriteString(fmt.Sprintf("CRON-based schedule: %s\n", cronDescription)) - } - } - if s.Interval > 0 { - sb.WriteString(fmt.Sprintf("Recurring every %s\n", s.Interval)) - } - if len(s.DayOfMonth) > 0 { - sb.WriteString("Occurs on the following days of the month: ") - for i, day := range s.DayOfMonth { - if i > 0 { - sb.WriteString(", ") - } - 
sb.WriteString(fmt.Sprintf("%d", day)) - } - sb.WriteString("\n") - } - if len(s.DayOfWeek) > 0 { - sb.WriteString("Occurs on the following days of the week: ") - for i, day := range s.DayOfWeek { - if i > 0 { - sb.WriteString(", ") - } - sb.WriteString(day.String()) - } - sb.WriteString("\n") - } - if !s.TimeOfDay.IsZero() { - sb.WriteString(fmt.Sprintf("Time of day: %s\n", s.TimeOfDay.Format("15:04"))) - } - if s.Recurring { - sb.WriteString("This schedule is recurring.\n") - } else { - sb.WriteString("This schedule is one-time.\n") - } - if sb.Len() == 0 { - sb.WriteString("No schedule defined.") - } - return sb.String() -} - -type CronSchedule struct { - Minute string - Hour string - DayOfMonth string - Month string - DayOfWeek string -} - -func (c CronSchedule) String() string { - return fmt.Sprintf("At %s minute(s) past %s, on %s, during %s, every %s", c.Minute, c.Hour, c.DayOfWeek, c.Month, c.DayOfMonth) -} - -func parseCronSpec(cronSpec string) (CronSchedule, error) { - parts := strings.Fields(cronSpec) - if len(parts) != 5 { - return CronSchedule{}, fmt.Errorf("invalid CRON spec: expected 5 fields, got %d", len(parts)) - } - minute, err := cronFieldToString(parts[0], "minute") - if err != nil { - return CronSchedule{}, err - } - hour, err := cronFieldToString(parts[1], "hour") - if err != nil { - return CronSchedule{}, err - } - dayOfMonth, err := cronFieldToString(parts[2], "day of the month") - if err != nil { - return CronSchedule{}, err - } - month, err := cronFieldToString(parts[3], "month") - if err != nil { - return CronSchedule{}, err - } - dayOfWeek, err := cronFieldToString(parts[4], "day of the week") - if err != nil { - return CronSchedule{}, err - } - return CronSchedule{ - Minute: minute, - Hour: hour, - DayOfMonth: dayOfMonth, - Month: month, - DayOfWeek: dayOfWeek, - }, nil -} - -func cronFieldToString(field string, fieldName string) (string, error) { - switch field { - case "*": - return fmt.Sprintf("every %s", fieldName), nil - default: - values, err := parseCronValue(field) - if err != nil { - return "", fmt.Errorf("invalid %s field: %s", fieldName, err.Error()) - } - return fmt.Sprintf("%s %s", strings.Join(values, ", "), fieldName), nil - } -} - -func parseCronValue(field string) ([]string, error) { - var values []string - ranges := strings.Split(field, ",") - for _, r := range ranges { - if strings.Contains(r, "-") { - bounds := strings.Split(r, "-") - if len(bounds) != 2 { - return nil, fmt.Errorf("invalid range: %s", r) - } - start, err := strconv.Atoi(bounds[0]) - if err != nil { - return nil, err - } - end, err := strconv.Atoi(bounds[1]) - if err != nil { - return nil, err - } - for i := start; i <= end; i++ { - values = append(values, strconv.Itoa(i)) - } - } else { - values = append(values, r) - } - } - return values, nil -} - -type Scheduler struct { - tasks []ScheduledTask - mu sync.Mutex - pool *Pool -} - -func (s *Scheduler) Start() { - for _, task := range s.tasks { - go s.schedule(task) - } -} - -func (s *Scheduler) schedule(task ScheduledTask) { - if task.schedule.Interval > 0 { - ticker := time.NewTicker(task.schedule.Interval) - defer ticker.Stop() - for { - select { - case <-ticker.C: - s.executeTask(task) - case <-task.stop: - return - } - } - } else if task.schedule.Recurring { - for { - now := time.Now() - nextRun := task.getNextRunTime(now) - if nextRun.After(now) { - time.Sleep(nextRun.Sub(now)) - } - s.executeTask(task) - } - } -} - -func NewScheduler(pool *Pool) *Scheduler { - return &Scheduler{pool: pool} -} - -func (task 
ScheduledTask) getNextRunTime(now time.Time) time.Time { - if task.schedule.CronSpec != "" { - return task.getNextCronRunTime(now) - } - if len(task.schedule.DayOfMonth) > 0 { - for _, day := range task.schedule.DayOfMonth { - nextRun := time.Date(now.Year(), now.Month(), day, task.schedule.TimeOfDay.Hour(), task.schedule.TimeOfDay.Minute(), 0, 0, now.Location()) - if nextRun.After(now) { - return nextRun - } - } - nextMonth := now.AddDate(0, 1, 0) - return time.Date(nextMonth.Year(), nextMonth.Month(), task.schedule.DayOfMonth[0], task.schedule.TimeOfDay.Hour(), task.schedule.TimeOfDay.Minute(), 0, 0, now.Location()) - } - if len(task.schedule.DayOfWeek) > 0 { - for _, weekday := range task.schedule.DayOfWeek { - nextRun := nextWeekday(now, weekday).Truncate(time.Minute).Add(task.schedule.TimeOfDay.Sub(time.Time{})) - if nextRun.After(now) { - return nextRun - } - } - } - return now -} - -func (task ScheduledTask) getNextCronRunTime(now time.Time) time.Time { - cronSpecs, err := parseCronSpec(task.schedule.CronSpec) - if err != nil { - fmt.Println(fmt.Sprintf("Invalid CRON spec format: %s", err.Error())) - return now - } - nextRun := now - nextRun = task.applyCronField(nextRun, cronSpecs.Minute, "minute") - nextRun = task.applyCronField(nextRun, cronSpecs.Hour, "hour") - nextRun = task.applyCronField(nextRun, cronSpecs.DayOfMonth, "day") - nextRun = task.applyCronField(nextRun, cronSpecs.Month, "month") - nextRun = task.applyCronField(nextRun, cronSpecs.DayOfWeek, "weekday") - return nextRun -} - -func (task ScheduledTask) applyCronField(t time.Time, fieldSpec string, unit string) time.Time { - switch fieldSpec { - case "*": - return t - default: - value, _ := strconv.Atoi(fieldSpec) - switch unit { - case "minute": - if t.Minute() > value { - t = t.Add(time.Hour) - } - t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), value, 0, 0, t.Location()) - case "hour": - if t.Hour() > value { - t = t.AddDate(0, 0, 1) - } - t = time.Date(t.Year(), t.Month(), t.Day(), value, t.Minute(), 0, 0, t.Location()) - case "day": - if t.Day() > value { - t = t.AddDate(0, 1, 0) - } - t = time.Date(t.Year(), t.Month(), value, t.Hour(), t.Minute(), 0, 0, t.Location()) - case "month": - if int(t.Month()) > value { - t = t.AddDate(1, 0, 0) - } - t = time.Date(t.Year(), time.Month(value), t.Day(), t.Hour(), t.Minute(), 0, 0, t.Location()) - case "weekday": - weekday := time.Weekday(value) - for t.Weekday() != weekday { - t = t.AddDate(0, 0, 1) - } - } - return t - } -} - -func nextWeekday(t time.Time, weekday time.Weekday) time.Time { - daysUntil := (int(weekday) - int(t.Weekday()) + 7) % 7 - if daysUntil == 0 { - daysUntil = 7 - } - return t.AddDate(0, 0, daysUntil) -} - -func (s *Scheduler) executeTask(task ScheduledTask) { - if task.config.Overlap || len(task.schedule.DayOfWeek) == 0 { - go func() { - result := task.handler(task.ctx, task.payload) - task.executionHistory = append(task.executionHistory, ExecutionHistory{Timestamp: time.Now(), Result: result}) - if task.config.Callback != nil { - _ = task.config.Callback(task.ctx, result) - } - fmt.Printf("Executed scheduled task: %s\n", task.payload.ID) - }() - } -} - -func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...SchedulerOption) { - s.mu.Lock() - defer s.mu.Unlock() - - // Create a default options instance - options := defaultSchedulerOptions() - for _, opt := range opts { - opt(options) - } - if options.Handler == nil { - options.Handler = s.pool.handler - } - if options.Callback == nil { - options.Callback = s.pool.callback - } - 
stop := make(chan struct{}) - - // Create a new ScheduledTask using the provided options - s.tasks = append(s.tasks, ScheduledTask{ - ctx: ctx, - handler: options.Handler, - payload: payload, - stop: stop, - config: SchedulerConfig{ - Callback: options.Callback, - Overlap: options.Overlap, - }, - schedule: &Schedule{ - Interval: options.Interval, - Recurring: options.Recurring, - }, - }) - - // Start scheduling the task - go s.schedule(s.tasks[len(s.tasks)-1]) -} - -func (s *Scheduler) RemoveTask(payloadID string) { - s.mu.Lock() - defer s.mu.Unlock() - for i, task := range s.tasks { - if task.payload.ID == payloadID { - close(task.stop) - s.tasks = append(s.tasks[:i], s.tasks[i+1:]...) - break - } - } -} - type Pool struct { + taskStorage TaskStorage taskQueue PriorityQueue taskQueueLock sync.Mutex stop chan struct{} @@ -452,7 +28,7 @@ type Pool struct { totalTasks int numOfWorkers int32 paused bool - Scheduler *Scheduler + scheduler *Scheduler totalScheduledTasks int } @@ -460,7 +36,8 @@ func NewPool( numOfWorkers, taskQueueSize int, maxMemoryLoad int64, handler Handler, - callback Callback) *Pool { + callback Callback, + storage TaskStorage) *Pool { pool := &Pool{ taskQueue: make(PriorityQueue, 0, taskQueueSize), stop: make(chan struct{}), @@ -468,16 +45,25 @@ func NewPool( maxMemoryLoad: maxMemoryLoad, handler: handler, callback: callback, + taskStorage: storage, workerAdjust: make(chan int), } - pool.Scheduler = NewScheduler(pool) + pool.scheduler = NewScheduler(pool) heap.Init(&pool.taskQueue) - pool.Scheduler.Start() + pool.scheduler.Start() pool.Start(numOfWorkers) return pool } func (wp *Pool) Start(numWorkers int) { + storedTasks, err := wp.taskStorage.GetAllTasks() + if err == nil { + wp.taskQueueLock.Lock() + for _, task := range storedTasks { + heap.Push(&wp.taskQueue, task) + } + wp.taskQueueLock.Unlock() + } for i := 0; i < numWorkers; i++ { wp.wg.Add(1) go wp.worker() @@ -492,9 +78,20 @@ func (wp *Pool) worker() { select { case <-wp.taskNotify: wp.taskQueueLock.Lock() + var task *QueueTask if len(wp.taskQueue) > 0 && !wp.paused { - task := heap.Pop(&wp.taskQueue).(*QueueTask) - wp.taskQueueLock.Unlock() + task = heap.Pop(&wp.taskQueue).(*QueueTask) + } + wp.taskQueueLock.Unlock() + if task == nil && !wp.paused { + var err error + task, err = wp.taskStorage.FetchNextTask() + if err != nil { + + continue + } + } + if task != nil { taskSize := int64(utils.SizeOf(task.payload)) wp.totalMemoryUsed += taskSize wp.totalTasks++ @@ -508,10 +105,11 @@ func (wp *Pool) worker() { if err := wp.callback(task.ctx, result); err != nil { wp.errorCount++ } + } + if err := wp.taskStorage.DeleteTask(task.payload.ID); err != nil { + } wp.totalMemoryUsed -= taskSize - } else { - wp.taskQueueLock.Unlock() } case <-wp.stop: return @@ -549,13 +147,16 @@ func (wp *Pool) adjustWorkers(newWorkerCount int) { atomic.StoreInt32(&wp.numOfWorkers, int32(newWorkerCount)) } -func (wp *Pool) AddTask(ctx context.Context, payload *Task, priority int) error { +func (wp *Pool) EnqueueTask(ctx context.Context, payload *Task, priority int) error { if payload.ID == "" { payload.ID = NewID() } + task := &QueueTask{ctx: ctx, payload: payload, priority: priority} + if err := wp.taskStorage.SaveTask(task); err != nil { + return err + } wp.taskQueueLock.Lock() defer wp.taskQueueLock.Unlock() - task := &QueueTask{ctx: ctx, payload: payload, priority: priority} taskSize := int64(utils.SizeOf(payload)) if wp.totalMemoryUsed+taskSize > wp.maxMemoryLoad && wp.maxMemoryLoad > 0 { return fmt.Errorf("max memory load reached, 
cannot add task of size %d", taskSize) @@ -590,34 +191,9 @@ func (wp *Pool) AdjustWorkerCount(newWorkerCount int) { func (wp *Pool) PrintMetrics() { fmt.Printf("Total Tasks: %d, Completed Tasks: %d, Error Count: %d, Total Memory Used: %d bytes, Total Scheduled Tasks: %d\n", - wp.totalTasks, wp.completedTasks, wp.errorCount, wp.totalMemoryUsed, len(wp.Scheduler.tasks)) + wp.totalTasks, wp.completedTasks, wp.errorCount, wp.totalMemoryUsed, len(wp.scheduler.tasks)) } -type ExecutionHistory struct { - Timestamp time.Time - Result Result -} - -func (s *Scheduler) PrintAllTasks() { - s.mu.Lock() - defer s.mu.Unlock() - fmt.Println("Scheduled Tasks:") - for _, task := range s.tasks { - fmt.Printf("Task ID: %s, Next Execution: %s\n", task.payload.ID, task.getNextRunTime(time.Now())) - } -} - -func (s *Scheduler) PrintExecutionHistory(taskID string) { - s.mu.Lock() - defer s.mu.Unlock() - for _, task := range s.tasks { - if task.payload.ID == taskID { - fmt.Printf("Execution History for Task ID: %s\n", taskID) - for _, history := range task.executionHistory { - fmt.Printf("Timestamp: %s, Result: %v\n", history.Timestamp, history.Result.Error) - } - return - } - } - fmt.Printf("No task found with ID: %s\n", taskID) +func (wp *Pool) Scheduler() *Scheduler { + return wp.scheduler } diff --git a/queue.go b/queue.go index 6de8a23..8b4428c 100644 --- a/queue.go +++ b/queue.go @@ -1,6 +1,8 @@ package mq import ( + "context" + "github.com/oarkflow/mq/storage" "github.com/oarkflow/mq/storage/memory" ) @@ -29,3 +31,32 @@ func (b *Broker) NewQueue(qName string) *Queue { go b.dispatchWorker(q) return q } + +type QueueTask struct { + ctx context.Context + payload *Task + priority int +} + +type PriorityQueue []*QueueTask + +func (pq PriorityQueue) Len() int { return len(pq) } + +func (pq PriorityQueue) Less(i, j int) bool { + return pq[i].priority > pq[j].priority +} + +func (pq PriorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] } + +func (pq *PriorityQueue) Push(x interface{}) { + item := x.(*QueueTask) + *pq = append(*pq, item) +} + +func (pq *PriorityQueue) Pop() interface{} { + old := *pq + n := len(old) + item := old[n-1] + *pq = old[0 : n-1] + return item +} diff --git a/scheduler.go b/scheduler.go new file mode 100644 index 0000000..8cc88dc --- /dev/null +++ b/scheduler.go @@ -0,0 +1,432 @@ +package mq + +import ( + "context" + "fmt" + "strconv" + "strings" + "sync" + "time" +) + +type ScheduleOptions struct { + Handler Handler + Callback Callback + Overlap bool + Interval time.Duration + Recurring bool +} + +type SchedulerOption func(*ScheduleOptions) + +// Helper functions to create SchedulerOptions +func WithSchedulerHandler(handler Handler) SchedulerOption { + return func(opts *ScheduleOptions) { + opts.Handler = handler + } +} + +func WithSchedulerCallback(callback Callback) SchedulerOption { + return func(opts *ScheduleOptions) { + opts.Callback = callback + } +} + +func WithOverlap() SchedulerOption { + return func(opts *ScheduleOptions) { + opts.Overlap = true + } +} + +func WithInterval(interval time.Duration) SchedulerOption { + return func(opts *ScheduleOptions) { + opts.Interval = interval + } +} + +func WithRecurring() SchedulerOption { + return func(opts *ScheduleOptions) { + opts.Recurring = true + } +} + +// defaultOptions returns the default scheduling options +func defaultSchedulerOptions() *ScheduleOptions { + return &ScheduleOptions{ + Interval: time.Minute, + Recurring: true, + } +} + +type SchedulerConfig struct { + Callback Callback + Overlap bool +} + +type ScheduledTask 
struct { + ctx context.Context + handler Handler + payload *Task + config SchedulerConfig + schedule *Schedule + stop chan struct{} + executionHistory []ExecutionHistory +} + +type Schedule struct { + Interval time.Duration + DayOfWeek []time.Weekday + DayOfMonth []int + TimeOfDay time.Time + Recurring bool + CronSpec string +} + +func (s *Schedule) ToHumanReadable() string { + var sb strings.Builder + if s.CronSpec != "" { + cronDescription, err := parseCronSpec(s.CronSpec) + if err != nil { + sb.WriteString(fmt.Sprintf("Invalid CRON spec: %s\n", err.Error())) + } else { + sb.WriteString(fmt.Sprintf("CRON-based schedule: %s\n", cronDescription)) + } + } + if s.Interval > 0 { + sb.WriteString(fmt.Sprintf("Recurring every %s\n", s.Interval)) + } + if len(s.DayOfMonth) > 0 { + sb.WriteString("Occurs on the following days of the month: ") + for i, day := range s.DayOfMonth { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(fmt.Sprintf("%d", day)) + } + sb.WriteString("\n") + } + if len(s.DayOfWeek) > 0 { + sb.WriteString("Occurs on the following days of the week: ") + for i, day := range s.DayOfWeek { + if i > 0 { + sb.WriteString(", ") + } + sb.WriteString(day.String()) + } + sb.WriteString("\n") + } + if !s.TimeOfDay.IsZero() { + sb.WriteString(fmt.Sprintf("Time of day: %s\n", s.TimeOfDay.Format("15:04"))) + } + if s.Recurring { + sb.WriteString("This schedule is recurring.\n") + } else { + sb.WriteString("This schedule is one-time.\n") + } + if sb.Len() == 0 { + sb.WriteString("No schedule defined.") + } + return sb.String() +} + +type CronSchedule struct { + Minute string + Hour string + DayOfMonth string + Month string + DayOfWeek string +} + +func (c CronSchedule) String() string { + return fmt.Sprintf("At %s minute(s) past %s, on %s, during %s, every %s", c.Minute, c.Hour, c.DayOfWeek, c.Month, c.DayOfMonth) +} + +func parseCronSpec(cronSpec string) (CronSchedule, error) { + parts := strings.Fields(cronSpec) + if len(parts) != 5 { + return CronSchedule{}, fmt.Errorf("invalid CRON spec: expected 5 fields, got %d", len(parts)) + } + minute, err := cronFieldToString(parts[0], "minute") + if err != nil { + return CronSchedule{}, err + } + hour, err := cronFieldToString(parts[1], "hour") + if err != nil { + return CronSchedule{}, err + } + dayOfMonth, err := cronFieldToString(parts[2], "day of the month") + if err != nil { + return CronSchedule{}, err + } + month, err := cronFieldToString(parts[3], "month") + if err != nil { + return CronSchedule{}, err + } + dayOfWeek, err := cronFieldToString(parts[4], "day of the week") + if err != nil { + return CronSchedule{}, err + } + return CronSchedule{ + Minute: minute, + Hour: hour, + DayOfMonth: dayOfMonth, + Month: month, + DayOfWeek: dayOfWeek, + }, nil +} + +func cronFieldToString(field string, fieldName string) (string, error) { + switch field { + case "*": + return fmt.Sprintf("every %s", fieldName), nil + default: + values, err := parseCronValue(field) + if err != nil { + return "", fmt.Errorf("invalid %s field: %s", fieldName, err.Error()) + } + return fmt.Sprintf("%s %s", strings.Join(values, ", "), fieldName), nil + } +} + +func parseCronValue(field string) ([]string, error) { + var values []string + ranges := strings.Split(field, ",") + for _, r := range ranges { + if strings.Contains(r, "-") { + bounds := strings.Split(r, "-") + if len(bounds) != 2 { + return nil, fmt.Errorf("invalid range: %s", r) + } + start, err := strconv.Atoi(bounds[0]) + if err != nil { + return nil, err + } + end, err := strconv.Atoi(bounds[1]) + if err 
!= nil { + return nil, err + } + for i := start; i <= end; i++ { + values = append(values, strconv.Itoa(i)) + } + } else { + values = append(values, r) + } + } + return values, nil +} + +type Scheduler struct { + tasks []ScheduledTask + mu sync.Mutex + pool *Pool +} + +func (s *Scheduler) Start() { + for _, task := range s.tasks { + go s.schedule(task) + } +} + +func (s *Scheduler) schedule(task ScheduledTask) { + if task.schedule.Interval > 0 { + ticker := time.NewTicker(task.schedule.Interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + s.executeTask(task) + case <-task.stop: + return + } + } + } else if task.schedule.Recurring { + for { + now := time.Now() + nextRun := task.getNextRunTime(now) + if nextRun.After(now) { + time.Sleep(nextRun.Sub(now)) + } + s.executeTask(task) + } + } +} + +func NewScheduler(pool *Pool) *Scheduler { + return &Scheduler{pool: pool} +} + +func (task ScheduledTask) getNextRunTime(now time.Time) time.Time { + if task.schedule.CronSpec != "" { + return task.getNextCronRunTime(now) + } + if len(task.schedule.DayOfMonth) > 0 { + for _, day := range task.schedule.DayOfMonth { + nextRun := time.Date(now.Year(), now.Month(), day, task.schedule.TimeOfDay.Hour(), task.schedule.TimeOfDay.Minute(), 0, 0, now.Location()) + if nextRun.After(now) { + return nextRun + } + } + nextMonth := now.AddDate(0, 1, 0) + return time.Date(nextMonth.Year(), nextMonth.Month(), task.schedule.DayOfMonth[0], task.schedule.TimeOfDay.Hour(), task.schedule.TimeOfDay.Minute(), 0, 0, now.Location()) + } + if len(task.schedule.DayOfWeek) > 0 { + for _, weekday := range task.schedule.DayOfWeek { + nextRun := nextWeekday(now, weekday).Truncate(time.Minute).Add(task.schedule.TimeOfDay.Sub(time.Time{})) + if nextRun.After(now) { + return nextRun + } + } + } + return now +} + +func (task ScheduledTask) getNextCronRunTime(now time.Time) time.Time { + cronSpecs, err := parseCronSpec(task.schedule.CronSpec) + if err != nil { + fmt.Println(fmt.Sprintf("Invalid CRON spec format: %s", err.Error())) + return now + } + nextRun := now + nextRun = task.applyCronField(nextRun, cronSpecs.Minute, "minute") + nextRun = task.applyCronField(nextRun, cronSpecs.Hour, "hour") + nextRun = task.applyCronField(nextRun, cronSpecs.DayOfMonth, "day") + nextRun = task.applyCronField(nextRun, cronSpecs.Month, "month") + nextRun = task.applyCronField(nextRun, cronSpecs.DayOfWeek, "weekday") + return nextRun +} + +func (task ScheduledTask) applyCronField(t time.Time, fieldSpec string, unit string) time.Time { + switch fieldSpec { + case "*": + return t + default: + value, _ := strconv.Atoi(fieldSpec) + switch unit { + case "minute": + if t.Minute() > value { + t = t.Add(time.Hour) + } + t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), value, 0, 0, t.Location()) + case "hour": + if t.Hour() > value { + t = t.AddDate(0, 0, 1) + } + t = time.Date(t.Year(), t.Month(), t.Day(), value, t.Minute(), 0, 0, t.Location()) + case "day": + if t.Day() > value { + t = t.AddDate(0, 1, 0) + } + t = time.Date(t.Year(), t.Month(), value, t.Hour(), t.Minute(), 0, 0, t.Location()) + case "month": + if int(t.Month()) > value { + t = t.AddDate(1, 0, 0) + } + t = time.Date(t.Year(), time.Month(value), t.Day(), t.Hour(), t.Minute(), 0, 0, t.Location()) + case "weekday": + weekday := time.Weekday(value) + for t.Weekday() != weekday { + t = t.AddDate(0, 0, 1) + } + } + return t + } +} + +func nextWeekday(t time.Time, weekday time.Weekday) time.Time { + daysUntil := (int(weekday) - int(t.Weekday()) + 7) % 7 + if daysUntil == 0 { + 
daysUntil = 7 + } + return t.AddDate(0, 0, daysUntil) +} + +func (s *Scheduler) executeTask(task ScheduledTask) { + if task.config.Overlap || len(task.schedule.DayOfWeek) == 0 { + go func() { + result := task.handler(task.ctx, task.payload) + task.executionHistory = append(task.executionHistory, ExecutionHistory{Timestamp: time.Now(), Result: result}) + if task.config.Callback != nil { + _ = task.config.Callback(task.ctx, result) + } + fmt.Printf("Executed scheduled task: %s\n", task.payload.ID) + }() + } +} + +func (s *Scheduler) AddTask(ctx context.Context, payload *Task, opts ...SchedulerOption) { + s.mu.Lock() + defer s.mu.Unlock() + + // Create a default options instance + options := defaultSchedulerOptions() + for _, opt := range opts { + opt(options) + } + if options.Handler == nil { + options.Handler = s.pool.handler + } + if options.Callback == nil { + options.Callback = s.pool.callback + } + stop := make(chan struct{}) + + // Create a new ScheduledTask using the provided options + s.tasks = append(s.tasks, ScheduledTask{ + ctx: ctx, + handler: options.Handler, + payload: payload, + stop: stop, + config: SchedulerConfig{ + Callback: options.Callback, + Overlap: options.Overlap, + }, + schedule: &Schedule{ + Interval: options.Interval, + Recurring: options.Recurring, + }, + }) + + // Start scheduling the task + go s.schedule(s.tasks[len(s.tasks)-1]) +} + +func (s *Scheduler) RemoveTask(payloadID string) { + s.mu.Lock() + defer s.mu.Unlock() + for i, task := range s.tasks { + if task.payload.ID == payloadID { + close(task.stop) + s.tasks = append(s.tasks[:i], s.tasks[i+1:]...) + break + } + } +} + +type ExecutionHistory struct { + Timestamp time.Time + Result Result +} + +func (s *Scheduler) PrintAllTasks() { + s.mu.Lock() + defer s.mu.Unlock() + fmt.Println("Scheduled Tasks:") + for _, task := range s.tasks { + fmt.Printf("Task ID: %s, Next Execution: %s\n", task.payload.ID, task.getNextRunTime(time.Now())) + } +} + +func (s *Scheduler) PrintExecutionHistory(taskID string) { + s.mu.Lock() + defer s.mu.Unlock() + for _, task := range s.tasks { + if task.payload.ID == taskID { + fmt.Printf("Execution History for Task ID: %s\n", taskID) + for _, history := range task.executionHistory { + fmt.Printf("Timestamp: %s, Result: %v\n", history.Timestamp, history.Result.Error) + } + return + } + } + fmt.Printf("No task found with ID: %s\n", taskID) +} diff --git a/storage.go b/storage.go new file mode 100644 index 0000000..0430a22 --- /dev/null +++ b/storage.go @@ -0,0 +1,203 @@ +package mq + +import ( + "container/heap" + "fmt" + "sync" + "time" +) + +type TaskStorage interface { + SaveTask(task *QueueTask) error + GetTask(taskID string) (*QueueTask, error) + DeleteTask(taskID string) error + GetAllTasks() ([]*QueueTask, error) + FetchNextTask() (*QueueTask, error) + CleanupExpiredTasks() error +} + +type MemoryTaskStorage struct { + tasks PriorityQueue + taskLock sync.Mutex + expiryTime time.Duration +} + +func NewMemoryTaskStorage(expiryTime time.Duration) *MemoryTaskStorage { + return &MemoryTaskStorage{ + tasks: make(PriorityQueue, 0), + expiryTime: expiryTime, + } +} + +func (m *MemoryTaskStorage) SaveTask(task *QueueTask) error { + m.taskLock.Lock() + defer m.taskLock.Unlock() + heap.Push(&m.tasks, task) + return nil +} + +func (m *MemoryTaskStorage) GetTask(taskID string) (*QueueTask, error) { + m.taskLock.Lock() + defer m.taskLock.Unlock() + for _, task := range m.tasks { + if task.payload.ID == taskID { + return task, nil + } + } + return nil, fmt.Errorf("task not found") +} + +func 
(m *MemoryTaskStorage) DeleteTask(taskID string) error { + m.taskLock.Lock() + defer m.taskLock.Unlock() + for i, task := range m.tasks { + if task.payload.ID == taskID { + heap.Remove(&m.tasks, i) + return nil + } + } + return fmt.Errorf("task not found") +} + +func (m *MemoryTaskStorage) GetAllTasks() ([]*QueueTask, error) { + m.taskLock.Lock() + defer m.taskLock.Unlock() + tasks := make([]*QueueTask, len(m.tasks)) + for i, task := range m.tasks { + tasks[i] = task + } + return tasks, nil +} + +func (m *MemoryTaskStorage) FetchNextTask() (*QueueTask, error) { + m.taskLock.Lock() + defer m.taskLock.Unlock() + if len(m.tasks) == 0 { + return nil, fmt.Errorf("no tasks available") + } + + task := heap.Pop(&m.tasks).(*QueueTask) + if task.payload.CreatedAt.Add(m.expiryTime).Before(time.Now()) { + m.DeleteTask(task.payload.ID) + return m.FetchNextTask() + } + return task, nil +} + +func (m *MemoryTaskStorage) CleanupExpiredTasks() error { + m.taskLock.Lock() + defer m.taskLock.Unlock() + + for i := 0; i < len(m.tasks); i++ { + task := m.tasks[i] + if task.payload.CreatedAt.Add(m.expiryTime).Before(time.Now()) { + heap.Remove(&m.tasks, i) + i-- // Adjust index after removal + } + } + return nil +} + +/* +type PostgresTaskStorage struct { + db *sql.DB +} + +func NewPostgresTaskStorage(db *sql.DB) *PostgresTaskStorage { + return &PostgresTaskStorage{db: db} +} + +func (p *PostgresTaskStorage) SaveTask(task *QueueTask) error { + query := ` + INSERT INTO tasks (id, payload, priority, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (id) DO NOTHING` + payloadBytes, err := utils.Serialize(task.payload) // Serialize converts the task to bytes + if err != nil { + return err + } + _, err = p.db.Exec(query, task.payload.ID, payloadBytes, task.priority, task.payload.CreatedAt, time.Now()) + return err +} + +func (p *PostgresTaskStorage) GetTask(taskID string) (*QueueTask, error) { + query := `SELECT id, payload, priority, created_at FROM tasks WHERE id = $1` + var task QueueTask + var payloadBytes []byte + err := p.db.QueryRow(query, taskID).Scan(&task.payload.ID, &payloadBytes, &task.priority, &task.payload.CreatedAt) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("task not found") + } + return nil, err + } + + task.payload, err = utils.Deserialize(payloadBytes) // Deserialize converts bytes to Task object + if err != nil { + return nil, err + } + return &task, nil +} + +func (p *PostgresTaskStorage) DeleteTask(taskID string) error { + query := `DELETE FROM tasks WHERE id = $1` + _, err := p.db.Exec(query, taskID) + return err +} + +func (p *PostgresTaskStorage) GetAllTasks() ([]*QueueTask, error) { + query := `SELECT id, payload, priority, created_at FROM tasks` + rows, err := p.db.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + + var tasks []*QueueTask + for rows.Next() { + var task QueueTask + var payloadBytes []byte + err := rows.Scan(&task.payload.ID, &payloadBytes, &task.priority, &task.payload.CreatedAt) + if err != nil { + return nil, err + } + + task.payload, err = utils.Deserialize(payloadBytes) + if err != nil { + return nil, err + } + + tasks = append(tasks, &task) + } + return tasks, nil +} + +func (p *PostgresTaskStorage) FetchNextTask() (*QueueTask, error) { + query := ` + SELECT id, payload, priority, created_at FROM tasks + ORDER BY priority DESC, created_at ASC + LIMIT 1` + + var task QueueTask + var payloadBytes []byte + err := p.db.QueryRow(query).Scan(&task.payload.ID, &payloadBytes, &task.priority, 
&task.payload.CreatedAt) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("no tasks available") + } + return nil, err + } + + task.payload, err = utils.Deserialize(payloadBytes) + if err != nil { + return nil, err + } + return &task, nil +} + +func (p *PostgresTaskStorage) CleanupExpiredTasks() error { + query := `DELETE FROM tasks WHERE created_at < $1` + _, err := p.db.Exec(query, time.Now().Add(-time.Hour*24)) // Assuming tasks older than 24 hours are expired + return err +} +*/
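
Notes on the API changes above (not part of the patch itself):

NewConsumer now takes a Processor value instead of a bare handler function and wires handler.ProcessTask into the worker pool. The Processor interface definition is not shown in this diff, so the sketch below assumes ProcessTask has the same signature as the existing Handler type (context and *mq.Task in, mq.Result out); EmailProcessor is a hypothetical type used only for illustration.

```go
package main

import (
	"context"

	"github.com/oarkflow/mq"
)

// EmailProcessor is a hypothetical task handler. NewConsumer now accepts a
// Processor and uses its ProcessTask method as the pool handler.
type EmailProcessor struct{}

// ProcessTask is assumed to match the Handler signature used elsewhere in the
// diff: it receives the task and returns an mq.Result.
func (p *EmailProcessor) ProcessTask(ctx context.Context, task *mq.Task) mq.Result {
	// Real work would happen here; report success for the sketch.
	return mq.Result{TaskID: task.ID, Topic: task.Topic, Status: "COMPLETED"}
}

func main() {
	// Mirrors examples/consumer.go: the third argument is now a Processor
	// implementation rather than a plain handler function.
	consumer := mq.NewConsumer("F", "queue1", &EmailProcessor{}, mq.WithWorkerPool(100, 4, 50000))
	_ = consumer.Consume(context.Background())
}
```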
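NewPool gains a TaskStorage parameter, AddTask is renamed EnqueueTask, and the scheduler is now reached through the Scheduler() accessor rather than the exported field. A minimal sketch of the new wiring, assuming the Handler and Callback function signatures shown elsewhere in the diff; the task IDs, sizes, and durations are illustrative only:

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/oarkflow/mq"
)

func main() {
	handler := func(ctx context.Context, task *mq.Task) mq.Result {
		fmt.Println("processing", task.ID)
		return mq.Result{TaskID: task.ID, Status: "COMPLETED"}
	}
	callback := func(ctx context.Context, result mq.Result) error {
		fmt.Println("done", result.TaskID, result.Status)
		return nil
	}

	// The pool now takes a TaskStorage as its final argument; tasks are saved
	// through it before being pushed onto the in-memory priority queue.
	pool := mq.NewPool(3, 10, 1_000_000, handler, callback, mq.NewMemoryTaskStorage(10*time.Minute))

	// AddTask was renamed to EnqueueTask; the last argument is the priority.
	_ = pool.EnqueueTask(context.Background(), &mq.Task{ID: "report-1"}, 5)

	// The scheduler is now behind the Scheduler() accessor and accepts
	// functional options such as WithInterval and WithRecurring.
	pool.Scheduler().AddTask(context.Background(), &mq.Task{ID: "cleanup"},
		mq.WithInterval(30*time.Second), mq.WithRecurring())

	time.Sleep(2 * time.Second)
	pool.PrintMetrics()
	pool.Stop()
}
```

Because Pool.Start now reloads whatever GetAllTasks returns from the storage, tasks that were enqueued but not yet processed survive a pool restart, subject to the storage's expiry window.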
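storage.go introduces the TaskStorage interface the pool persists through, with NewMemoryTaskStorage as the in-memory default (the PostgreSQL variant is only sketched in comments). Since QueueTask's fields are unexported, an out-of-package implementation can mostly only delegate; the hypothetical logging decorator below illustrates the interface contract under that assumption:

```go
package main

import (
	"log"
	"time"

	"github.com/oarkflow/mq"
)

// LoggingTaskStorage is a hypothetical decorator around any mq.TaskStorage.
// It passes *mq.QueueTask values through opaquely and logs each mutation.
type LoggingTaskStorage struct {
	next mq.TaskStorage
}

func (s *LoggingTaskStorage) SaveTask(task *mq.QueueTask) error {
	log.Println("saving task")
	return s.next.SaveTask(task)
}

func (s *LoggingTaskStorage) GetTask(taskID string) (*mq.QueueTask, error) {
	return s.next.GetTask(taskID)
}

func (s *LoggingTaskStorage) DeleteTask(taskID string) error {
	log.Println("deleting task", taskID)
	return s.next.DeleteTask(taskID)
}

func (s *LoggingTaskStorage) GetAllTasks() ([]*mq.QueueTask, error) {
	return s.next.GetAllTasks()
}

func (s *LoggingTaskStorage) FetchNextTask() (*mq.QueueTask, error) {
	return s.next.FetchNextTask()
}

func (s *LoggingTaskStorage) CleanupExpiredTasks() error {
	return s.next.CleanupExpiredTasks()
}

func main() {
	storage := &LoggingTaskStorage{next: mq.NewMemoryTaskStorage(10 * time.Minute)}
	_ = storage // pass this as the final argument to mq.NewPool
}
```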