Files
mq/mq.go
2025-03-29 17:38:47 +05:45

1268 lines
33 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mq
import (
"context"
"crypto/tls"
"fmt"
"log"
"net"
"strings"
"sync"
"time"
"github.com/oarkflow/errors"
"github.com/oarkflow/json"
"github.com/oarkflow/mq/codec"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/jsonparser"
"github.com/oarkflow/mq/logger"
"github.com/oarkflow/mq/storage"
"github.com/oarkflow/mq/storage/memory"
"github.com/oarkflow/mq/utils"
)
type Status string
const (
Pending Status = "Pending"
Processing Status = "Processing"
Completed Status = "Completed"
Failed Status = "Failed"
)
type Result struct {
CreatedAt time.Time `json:"created_at"`
ProcessedAt time.Time `json:"processed_at,omitempty"`
Latency string `json:"latency"`
Error error `json:"-"` // Keep error as an error type
Topic string `json:"topic"`
TaskID string `json:"task_id"`
Status Status `json:"status"`
ConditionStatus string `json:"condition_status"`
Ctx context.Context `json:"-"`
Payload json.RawMessage `json:"payload"`
Last bool
}
func (r Result) MarshalJSON() ([]byte, error) {
type Alias Result
aux := &struct {
ErrorMsg string `json:"error,omitempty"`
Alias
}{
Alias: (Alias)(r),
}
if r.Error != nil {
aux.ErrorMsg = r.Error.Error()
}
return json.Marshal(aux)
}
func (r *Result) UnmarshalJSON(data []byte) error {
type Alias Result
aux := &struct {
*Alias
ErrMsg string `json:"error,omitempty"`
}{
Alias: (*Alias)(r),
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
if aux.ErrMsg != "" {
r.Error = errors.New(aux.ErrMsg)
} else {
r.Error = nil
}
return nil
}
func (r Result) Unmarshal(data any) error {
if r.Payload == nil {
return fmt.Errorf("payload is nil")
}
return json.Unmarshal(r.Payload, data)
}
func HandleError(ctx context.Context, err error, status ...Status) Result {
st := Failed
if len(status) > 0 {
st = status[0]
}
if err == nil {
return Result{Ctx: ctx}
}
return Result{
Ctx: ctx,
Status: st,
Error: err,
}
}
func (r Result) WithData(status Status, data []byte) Result {
if r.Error != nil {
return r
}
return Result{
Status: status,
Payload: data,
Ctx: r.Ctx,
}
}
type TLSConfig struct {
CertPath string
KeyPath string
CAPath string
UseTLS bool
}
type Options struct {
storage TaskStorage
consumerOnSubscribe func(ctx context.Context, topic, consumerName string)
consumerOnClose func(ctx context.Context, topic, consumerName string)
notifyResponse func(context.Context, Result) error
brokerAddr string
tlsConfig TLSConfig
callback []func(context.Context, Result) Result
queueSize int
initialDelay time.Duration
maxBackoff time.Duration
jitterPercent float64
maxRetries int
numOfWorkers int
maxMemoryLoad int64
syncMode bool
cleanTaskOnComplete bool
enableWorkerPool bool
respondPendingResult bool
logger logger.Logger
}
func (o *Options) SetSyncMode(sync bool) {
o.syncMode = sync
}
func (o *Options) NumOfWorkers() int {
return o.numOfWorkers
}
func (o *Options) Logger() logger.Logger {
return o.logger
}
func (o *Options) Storage() TaskStorage {
return o.storage
}
func (o *Options) CleanTaskOnComplete() bool {
return o.cleanTaskOnComplete
}
func (o *Options) QueueSize() int {
return o.queueSize
}
func (o *Options) MaxMemoryLoad() int64 {
return o.maxMemoryLoad
}
func (o *Options) BrokerAddr() string {
return o.brokerAddr
}
func HeadersWithConsumerID(ctx context.Context, id string) map[string]string {
return WithHeaders(ctx, map[string]string{consts.ConsumerKey: id, consts.ContentType: consts.TypeJson})
}
func HeadersWithConsumerIDAndQueue(ctx context.Context, id, queue string) map[string]string {
return WithHeaders(ctx, map[string]string{
consts.ConsumerKey: id,
consts.ContentType: consts.TypeJson,
consts.QueueKey: queue,
})
}
type QueuedTask struct {
Message *codec.Message
RetryCount int
}
type consumer struct {
conn net.Conn
id string
state consts.ConsumerState
}
type publisher struct {
conn net.Conn
id string
}
type Broker struct {
queues storage.IMap[string, *Queue]
consumers storage.IMap[string, *consumer]
publishers storage.IMap[string, *publisher]
deadLetter storage.IMap[string, *Queue]
opts *Options
listener net.Listener
}
func NewBroker(opts ...Option) *Broker {
options := SetupOptions(opts...)
return &Broker{
queues: memory.New[string, *Queue](),
publishers: memory.New[string, *publisher](),
consumers: memory.New[string, *consumer](),
deadLetter: memory.New[string, *Queue](),
opts: options,
}
}
func (b *Broker) Options() *Options {
return b.opts
}
func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
consumerID, ok := GetConsumerID(ctx)
if ok && consumerID != "" {
if con, exists := b.consumers.Get(consumerID); exists {
con.conn.Close()
b.consumers.Del(consumerID)
}
b.queues.ForEach(func(_ string, queue *Queue) bool {
if _, ok := queue.consumers.Get(consumerID); ok {
if b.opts.consumerOnClose != nil {
b.opts.consumerOnClose(ctx, queue.name, consumerID)
}
queue.consumers.Del(consumerID)
}
return true
})
} else {
b.consumers.ForEach(func(consumerID string, con *consumer) bool {
if utils.ConnectionsEqual(conn, con.conn) {
con.conn.Close()
b.consumers.Del(consumerID)
b.queues.ForEach(func(_ string, queue *Queue) bool {
queue.consumers.Del(consumerID)
if _, ok := queue.consumers.Get(consumerID); ok {
if b.opts.consumerOnClose != nil {
b.opts.consumerOnClose(ctx, queue.name, consumerID)
}
}
return true
})
}
return true
})
}
publisherID, ok := GetPublisherID(ctx)
if ok && publisherID != "" {
if con, exists := b.publishers.Get(publisherID); exists {
con.conn.Close()
b.publishers.Del(publisherID)
}
}
return nil
}
func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) {
if conn != nil {
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
}
}
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
switch msg.Command {
case consts.PUBLISH:
b.PublishHandler(ctx, conn, msg)
case consts.SUBSCRIBE:
b.SubscribeHandler(ctx, conn, msg)
case consts.MESSAGE_RESPONSE:
b.MessageResponseHandler(ctx, msg)
case consts.MESSAGE_ACK:
b.MessageAck(ctx, msg)
case consts.MESSAGE_DENY:
b.MessageDeny(ctx, msg)
case consts.CONSUMER_PAUSED:
b.OnConsumerPause(ctx, msg)
case consts.CONSUMER_RESUMED:
b.OnConsumerResume(ctx, msg)
case consts.CONSUMER_STOPPED:
b.OnConsumerStop(ctx, msg)
default:
log.Printf("BROKER - UNKNOWN_COMMAND ~> %s on %s", msg.Command, msg.Queue)
}
}
func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) {
consumerID, _ := GetConsumerID(ctx)
taskID, _ := jsonparser.GetString(msg.Payload, "id")
log.Printf("BROKER - MESSAGE_ACK ~> %s on %s for Task %s", consumerID, msg.Queue, taskID)
}
func (b *Broker) MessageDeny(ctx context.Context, msg *codec.Message) {
consumerID, _ := GetConsumerID(ctx)
taskID, _ := jsonparser.GetString(msg.Payload, "id")
taskError, _ := jsonparser.GetString(msg.Payload, "error")
log.Printf("BROKER - MESSAGE_DENY ~> %s on %s for Task %s, Error: %s", consumerID, msg.Queue, taskID, taskError)
}
func (b *Broker) OnConsumerPause(ctx context.Context, _ *codec.Message) {
consumerID, _ := GetConsumerID(ctx)
if consumerID != "" {
if con, exists := b.consumers.Get(consumerID); exists {
con.state = consts.ConsumerStatePaused
log.Printf("BROKER - CONSUMER ~> Paused %s", consumerID)
}
}
}
func (b *Broker) OnConsumerStop(ctx context.Context, _ *codec.Message) {
consumerID, _ := GetConsumerID(ctx)
if consumerID != "" {
if con, exists := b.consumers.Get(consumerID); exists {
con.state = consts.ConsumerStateStopped
log.Printf("BROKER - CONSUMER ~> Stopped %s", consumerID)
}
}
}
func (b *Broker) OnConsumerResume(ctx context.Context, _ *codec.Message) {
consumerID, _ := GetConsumerID(ctx)
if consumerID != "" {
if con, exists := b.consumers.Get(consumerID); exists {
con.state = consts.ConsumerStateActive
log.Printf("BROKER - CONSUMER ~> Resumed %s", consumerID)
}
}
}
func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) {
msg.Command = consts.RESPONSE
b.HandleCallback(ctx, msg)
awaitResponse, ok := GetAwaitResponse(ctx)
if !(ok && awaitResponse == "true") {
return
}
publisherID, exists := GetPublisherID(ctx)
if !exists {
return
}
con, ok := b.publishers.Get(publisherID)
if !ok {
return
}
err := b.send(ctx, con.conn, msg)
if err != nil {
panic(err)
}
}
func (b *Broker) Publish(ctx context.Context, task *Task, queue string) error {
headers, _ := GetHeaders(ctx)
payload, err := json.Marshal(task)
if err != nil {
return err
}
msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers.AsMap())
b.broadcastToConsumers(msg)
return nil
}
func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
pub := b.addPublisher(ctx, msg.Queue, conn)
taskID, _ := jsonparser.GetString(msg.Payload, "id")
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
ack := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
if err := b.send(ctx, conn, ack); err != nil {
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
}
b.broadcastToConsumers(msg)
go func() {
select {
case <-ctx.Done():
b.publishers.Del(pub.id)
}
}()
}
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
consumerID := b.AddConsumer(ctx, msg.Queue, conn)
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
if err := b.send(ctx, conn, ack); err != nil {
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
}
if b.opts.consumerOnSubscribe != nil {
b.opts.consumerOnSubscribe(ctx, msg.Queue, consumerID)
}
go func() {
select {
case <-ctx.Done():
b.RemoveConsumer(consumerID, msg.Queue)
}
}()
}
func (b *Broker) Start(ctx context.Context) error {
var listener net.Listener
var err error
if b.opts.tlsConfig.UseTLS {
cert, err := tls.LoadX509KeyPair(b.opts.tlsConfig.CertPath, b.opts.tlsConfig.KeyPath)
if err != nil {
return fmt.Errorf("failed to load TLS certificates: %v", err)
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
listener, err = tls.Listen("tcp", b.opts.brokerAddr, tlsConfig)
if err != nil {
return fmt.Errorf("failed to start TLS listener: %v", err)
}
log.Println("BROKER - RUNNING_TLS ~> started on", b.opts.brokerAddr)
} else {
listener, err = net.Listen("tcp", b.opts.brokerAddr)
if err != nil {
return fmt.Errorf("failed to start TCP listener: %v", err)
}
log.Println("BROKER - RUNNING ~> started on", b.opts.brokerAddr)
}
b.listener = listener
defer b.Close()
const maxConcurrentConnections = 100
sem := make(chan struct{}, maxConcurrentConnections)
for {
conn, err := listener.Accept()
if err != nil {
b.OnError(ctx, conn, err)
time.Sleep(50 * time.Millisecond)
continue
}
sem <- struct{}{}
go func(c net.Conn) {
defer func() {
<-sem
c.Close()
}()
for {
err := b.readMessage(ctx, c)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
log.Println("Temporary network error, retrying:", netErr)
continue
}
log.Println("Connection closed due to error:", err)
break
}
}
}(conn)
}
}
func (b *Broker) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
return codec.SendMessage(ctx, conn, msg)
}
func (b *Broker) receive(ctx context.Context, c net.Conn) (*codec.Message, error) {
return codec.ReadMessage(ctx, c)
}
func (b *Broker) broadcastToConsumers(msg *codec.Message) {
if queue, ok := b.queues.Get(msg.Queue); ok {
task := &QueuedTask{Message: msg, RetryCount: 0}
queue.tasks <- task
}
}
func (b *Broker) waitForConsumerAck(ctx context.Context, conn net.Conn) error {
msg, err := b.receive(ctx, conn)
if err != nil {
return err
}
if msg.Command == consts.MESSAGE_ACK {
log.Println("Received CONSUMER_ACK: Subscribed successfully")
return nil
}
return fmt.Errorf("expected CONSUMER_ACK, got: %v", msg.Command)
}
func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) *publisher {
publisherID, ok := GetPublisherID(ctx)
_, ok = b.queues.Get(queueName)
if !ok {
b.NewQueue(queueName)
}
con := &publisher{id: publisherID, conn: conn}
b.publishers.Set(publisherID, con)
return con
}
func (b *Broker) AddConsumer(ctx context.Context, queueName string, conn net.Conn) string {
consumerID, ok := GetConsumerID(ctx)
q, ok := b.queues.Get(queueName)
if !ok {
q = b.NewQueue(queueName)
}
con := &consumer{id: consumerID, conn: conn}
b.consumers.Set(consumerID, con)
q.consumers.Set(consumerID, con)
log.Printf("BROKER - SUBSCRIBE ~> %s on %s", consumerID, queueName)
return consumerID
}
func (b *Broker) RemoveConsumer(consumerID string, queues ...string) {
if len(queues) > 0 {
for _, queueName := range queues {
if queue, ok := b.queues.Get(queueName); ok {
con, ok := queue.consumers.Get(consumerID)
if ok {
con.conn.Close()
queue.consumers.Del(consumerID)
}
b.queues.Del(queueName)
}
}
return
}
b.queues.ForEach(func(queueName string, queue *Queue) bool {
con, ok := queue.consumers.Get(consumerID)
if ok {
con.conn.Close()
queue.consumers.Del(consumerID)
}
b.queues.Del(queueName)
return true
})
}
func (b *Broker) handleConsumer(ctx context.Context, cmd consts.CMD, state consts.ConsumerState, consumerID string, queues ...string) {
fn := func(queue *Queue) {
con, ok := queue.consumers.Get(consumerID)
if ok {
ack := codec.NewMessage(cmd, utils.ToByte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID})
err := b.send(ctx, con.conn, ack)
if err == nil {
con.state = state
}
}
}
if len(queues) > 0 {
for _, queueName := range queues {
if queue, ok := b.queues.Get(queueName); ok {
fn(queue)
}
}
return
}
b.queues.ForEach(func(queueName string, queue *Queue) bool {
fn(queue)
return true
})
}
func (b *Broker) PauseConsumer(ctx context.Context, consumerID string, queues ...string) {
b.handleConsumer(ctx, consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, queues...)
}
func (b *Broker) ResumeConsumer(ctx context.Context, consumerID string, queues ...string) {
b.handleConsumer(ctx, consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, queues...)
}
func (b *Broker) StopConsumer(ctx context.Context, consumerID string, queues ...string) {
b.handleConsumer(ctx, consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, queues...)
}
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
msg, err := b.receive(ctx, c)
if err == nil {
ctx = SetHeaders(ctx, msg.Headers)
b.OnMessage(ctx, msg, c)
return nil
}
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
b.OnClose(ctx, c)
return err
}
b.OnError(ctx, c, err)
return err
}
func (b *Broker) dispatchWorker(ctx context.Context, queue *Queue) {
delay := b.opts.initialDelay
for task := range queue.tasks {
success := false
for !success && task.RetryCount <= b.opts.maxRetries {
if b.dispatchTaskToConsumer(ctx, queue, task) {
success = true
} else {
task.RetryCount++
delay = b.backoffRetry(queue, task, delay)
}
}
if task.RetryCount > b.opts.maxRetries {
b.sendToDLQ(queue, task)
}
}
}
func (b *Broker) sendToDLQ(queue *Queue, task *QueuedTask) {
id, _ := jsonparser.GetString(task.Message.Payload, "id")
if dlq, ok := b.deadLetter.Get(queue.name); ok {
log.Printf("Sending task %s to dead-letter queue for %s", id, queue.name)
dlq.tasks <- task
} else {
log.Printf("No dead-letter queue for %s, discarding task %s", queue.name, id)
}
}
func (b *Broker) dispatchTaskToConsumer(ctx context.Context, queue *Queue, task *QueuedTask) bool {
var consumerFound bool
var err error
queue.consumers.ForEach(func(_ string, con *consumer) bool {
if con.state != consts.ConsumerStateActive {
err = fmt.Errorf("consumer %s is not active", con.id)
return true
}
if err := b.send(ctx, con.conn, task.Message); err == nil {
consumerFound = true
return false
}
return true
})
if err != nil {
log.Println(err.Error())
return false
}
if !consumerFound {
log.Printf("No available consumers for queue %s, retrying...", queue.name)
}
return consumerFound
}
// Modified backoffRetry: Removed reinsertion of the task into queue.tasks.
func (b *Broker) backoffRetry(queue *Queue, task *QueuedTask, delay time.Duration) time.Duration {
backoffDuration := utils.CalculateJitter(delay, b.opts.jitterPercent)
log.Printf("Backing off for %v before retrying task for queue %s", backoffDuration, task.Message.Queue)
time.Sleep(backoffDuration)
delay *= 2
if delay > b.opts.maxBackoff {
delay = b.opts.maxBackoff
}
return delay
}
func (b *Broker) URL() string {
return b.opts.brokerAddr
}
func (b *Broker) Close() error {
if b != nil && b.listener != nil {
log.Printf("Broker is closing...")
return b.listener.Close()
}
return nil
}
func (b *Broker) SetURL(url string) {
b.opts.brokerAddr = url
}
type Processor interface {
ProcessTask(ctx context.Context, msg *Task) Result
Consume(ctx context.Context) error
Pause(ctx context.Context) error
Resume(ctx context.Context) error
Stop(ctx context.Context) error
Close() error
GetKey() string
SetKey(key string)
GetType() string
}
type Consumer struct {
conn net.Conn
handler Handler
pool *Pool
opts *Options
id string
queue string
}
func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Consumer {
options := SetupOptions(opts...)
return &Consumer{
id: id,
opts: options,
queue: queue,
handler: handler,
}
}
func (c *Consumer) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
return codec.SendMessage(ctx, conn, msg)
}
func (c *Consumer) receive(ctx context.Context, conn net.Conn) (*codec.Message, error) {
return codec.ReadMessage(ctx, conn)
}
func (c *Consumer) Close() error {
c.pool.Stop()
return c.conn.Close()
}
func (c *Consumer) GetKey() string {
return c.id
}
func (c *Consumer) GetType() string {
return "consumer"
}
func (c *Consumer) SetKey(key string) {
c.id = key
}
func (c *Consumer) Metrics() Metrics {
return c.pool.Metrics()
}
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
headers := HeadersWithConsumerID(ctx, c.id)
msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
if err := c.send(ctx, c.conn, msg); err != nil {
return fmt.Errorf("error while trying to subscribe: %v", err)
}
return c.waitForAck(ctx, c.conn)
}
func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error {
fmt.Println("Consumer closed")
return nil
}
func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) {
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
}
func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) error {
switch msg.Command {
case consts.PUBLISH:
c.ConsumeMessage(ctx, msg, conn)
case consts.CONSUMER_PAUSE:
err := c.Pause(ctx)
if err != nil {
log.Printf("Unable to pause consumer: %v", err)
}
return err
case consts.CONSUMER_RESUME:
err := c.Resume(ctx)
if err != nil {
log.Printf("Unable to resume consumer: %v", err)
}
return err
case consts.CONSUMER_STOP:
err := c.Stop(ctx)
if err != nil {
log.Printf("Unable to stop consumer: %v", err)
}
return err
default:
log.Printf("CONSUMER - UNKNOWN_COMMAND ~> %s on %s", msg.Command, msg.Queue)
}
return nil
}
func (c *Consumer) sendMessageAck(ctx context.Context, msg *codec.Message, conn net.Conn) {
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue)
taskID, _ := jsonparser.GetString(msg.Payload, "id")
reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
if err := c.send(ctx, conn, reply); err != nil {
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
}
}
func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
c.sendMessageAck(ctx, msg, conn)
if msg.Payload == nil {
log.Printf("Received empty message payload")
return
}
var task Task
err := json.Unmarshal(msg.Payload, &task)
if err != nil {
log.Printf("Error unmarshalling message: %v", err)
return
}
ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue})
if err := c.pool.EnqueueTask(ctx, &task, 1); err != nil {
c.sendDenyMessage(ctx, task.ID, msg.Queue, err)
return
}
}
func (c *Consumer) ProcessTask(ctx context.Context, msg *Task) Result {
defer RecoverPanic(RecoverTitle)
queue, _ := GetQueue(ctx)
if msg.Topic == "" && queue != "" {
msg.Topic = queue
}
result := c.handler(ctx, msg)
result.Topic = msg.Topic
result.TaskID = msg.ID
return result
}
func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
if result.Status == "PENDING" && c.opts.respondPendingResult {
return nil
}
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, result.Topic)
if result.Status == "" {
if result.Error != nil {
result.Status = "FAILED"
} else {
result.Status = "SUCCESS"
}
}
bt, _ := json.Marshal(result)
reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers)
if err := c.send(ctx, c.conn, reply); err != nil {
return fmt.Errorf("failed to send MESSAGE_RESPONSE: %v", err)
}
return nil
}
func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) {
headers := HeadersWithConsumerID(ctx, c.id)
reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
if sendErr := c.send(ctx, c.conn, reply); sendErr != nil {
log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr)
}
}
func (c *Consumer) attemptConnect() error {
var err error
delay := c.opts.initialDelay
for i := 0; i < c.opts.maxRetries; i++ {
conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig)
if err == nil {
c.conn = conn
return nil
}
sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent)
log.Printf("CONSUMER - SUBSCRIBE ~> Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
time.Sleep(sleepDuration)
delay *= 2
if delay > c.opts.maxBackoff {
delay = c.opts.maxBackoff
}
}
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err)
}
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
msg, err := c.receive(ctx, conn)
if err == nil {
ctx = SetHeaders(ctx, msg.Headers)
return c.OnMessage(ctx, msg, conn)
}
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
err1 := c.OnClose(ctx, conn)
if err1 != nil {
return err1
}
return err
}
c.OnError(ctx, conn, err)
return err
}
func (c *Consumer) Consume(ctx context.Context) error {
err := c.attemptConnect()
if err != nil {
return err
}
c.pool = NewPool(
c.opts.numOfWorkers,
WithTaskQueueSize(c.opts.queueSize),
WithMaxMemoryLoad(c.opts.maxMemoryLoad),
WithHandler(c.ProcessTask),
WithPoolCallback(c.OnResponse),
WithTaskStorage(c.opts.storage),
)
if err := c.subscribe(ctx, c.queue); err != nil {
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
}
c.pool.Start(c.opts.numOfWorkers)
// Infinite loop to continuously read messages and reconnect if needed.
for {
select {
case <-ctx.Done():
log.Println("Context canceled, stopping consumer.")
return nil
default:
if err := c.readMessage(ctx, c.conn); err != nil {
log.Printf("Error reading message: %v, attempting reconnection...", err)
// Attempt reconnection loop.
for {
if ctx.Err() != nil {
return nil
}
if rErr := c.attemptConnect(); rErr != nil {
log.Printf("Reconnection attempt failed: %v", rErr)
time.Sleep(c.opts.initialDelay)
} else {
break
}
}
if err := c.subscribe(ctx, c.queue); err != nil {
log.Printf("Failed to re-subscribe on reconnection: %v", err)
time.Sleep(c.opts.initialDelay)
}
}
}
}
}
func (c *Consumer) waitForAck(ctx context.Context, conn net.Conn) error {
msg, err := c.receive(ctx, conn)
if err != nil {
return err
}
if msg.Command == consts.SUBSCRIBE_ACK {
log.Printf("CONSUMER - SUBSCRIBE_ACK ~> %s on %s", c.id, msg.Queue)
return nil
}
return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command)
}
func (c *Consumer) Pause(ctx context.Context) error {
return c.operate(ctx, consts.CONSUMER_PAUSED, c.pool.Pause)
}
func (c *Consumer) Resume(ctx context.Context) error {
return c.operate(ctx, consts.CONSUMER_RESUMED, c.pool.Resume)
}
func (c *Consumer) Stop(ctx context.Context) error {
return c.operate(ctx, consts.CONSUMER_STOPPED, c.pool.Stop)
}
func (c *Consumer) operate(ctx context.Context, cmd consts.CMD, poolOperation func()) error {
if err := c.sendOpsMessage(ctx, cmd); err != nil {
return err
}
poolOperation()
return nil
}
func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
headers := HeadersWithConsumerID(ctx, c.id)
msg := codec.NewMessage(cmd, nil, c.queue, headers)
return c.send(ctx, c.conn, msg)
}
func (c *Consumer) Conn() net.Conn {
return c.conn
}
type Publisher struct {
opts *Options
id string
conn net.Conn
connLock sync.Mutex
}
func NewPublisher(id string, opts ...Option) *Publisher {
options := SetupOptions(opts...)
return &Publisher{
id: id,
opts: options,
conn: nil,
}
}
// New method to ensure a persistent connection.
func (p *Publisher) ensureConnection(ctx context.Context) error {
p.connLock.Lock()
defer p.connLock.Unlock()
if p.conn != nil {
return nil
}
var err error
delay := p.opts.initialDelay
for i := 0; i < p.opts.maxRetries; i++ {
var conn net.Conn
conn, err = GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
if err == nil {
p.conn = conn
return nil
}
sleepDuration := utils.CalculateJitter(delay, p.opts.jitterPercent)
log.Printf("PUBLISHER - ensureConnection failed: %v, attempt %d/%d, retrying in %v...", err, i+1, p.opts.maxRetries, sleepDuration)
time.Sleep(sleepDuration)
delay *= 2
if delay > p.opts.maxBackoff {
delay = p.opts.maxBackoff
}
}
return fmt.Errorf("failed to connect to broker after retries: %w", err)
}
// Modified Publish method that uses the persistent connection.
func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error {
// Ensure connection is established.
if err := p.ensureConnection(ctx); err != nil {
return err
}
delay := p.opts.initialDelay
for i := 0; i < p.opts.maxRetries; i++ {
// Use the persistent connection.
p.connLock.Lock()
conn := p.conn
p.connLock.Unlock()
err := p.send(ctx, queue, task, conn, consts.PUBLISH)
if err == nil {
return nil
}
log.Printf("PUBLISHER - Failed publishing: %v, attempt %d/%d, retrying...", err, i+1, p.opts.maxRetries)
// On error, close and reset the connection.
p.connLock.Lock()
if p.conn != nil {
p.conn.Close()
p.conn = nil
}
p.connLock.Unlock()
sleepDuration := utils.CalculateJitter(delay, p.opts.jitterPercent)
time.Sleep(sleepDuration)
delay *= 2
if delay > p.opts.maxBackoff {
delay = p.opts.maxBackoff
}
// Ensure connection is re-established.
if err := p.ensureConnection(ctx); err != nil {
return err
}
}
return fmt.Errorf("failed to publish after retries")
}
func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error {
headers := WithHeaders(ctx, map[string]string{
consts.PublisherKey: p.id,
consts.ContentType: consts.TypeJson,
})
if task.ID == "" {
task.ID = NewID()
}
task.CreatedAt = time.Now()
payload, err := json.Marshal(task)
if err != nil {
return err
}
msg := codec.NewMessage(command, payload, queue, headers)
if err := codec.SendMessage(ctx, conn, msg); err != nil {
return err
}
return p.waitForAck(ctx, conn)
}
func (p *Publisher) waitForAck(ctx context.Context, conn net.Conn) error {
msg, err := codec.ReadMessage(ctx, conn)
if err != nil {
return err
}
if msg.Command == consts.PUBLISH_ACK {
taskID, _ := jsonparser.GetString(msg.Payload, "id")
log.Printf("PUBLISHER - PUBLISH_ACK ~> from %s on %s for Task %s", p.id, msg.Queue, taskID)
return nil
}
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
}
func (p *Publisher) waitForResponse(ctx context.Context, conn net.Conn) Result {
msg, err := codec.ReadMessage(ctx, conn)
if err != nil {
return Result{Error: err}
}
if msg.Command == consts.RESPONSE {
var result Result
err = json.Unmarshal(msg.Payload, &result)
return result
}
err = fmt.Errorf("expected RESPONSE, got: %v", msg.Command)
return Result{Error: err}
}
func (p *Publisher) onClose(_ context.Context, conn net.Conn) error {
fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr())
return nil
}
func (p *Publisher) onError(_ context.Context, conn net.Conn, err error) {
fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr())
}
func (p *Publisher) Request(ctx context.Context, task Task, queue string) Result {
ctx = SetHeaders(ctx, map[string]string{
consts.AwaitResponseKey: "true",
})
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
if err != nil {
err = fmt.Errorf("failed to connect to broker: %w", err)
return Result{Error: err}
}
defer func() {
_ = conn.Close()
}()
err = p.send(ctx, queue, task, conn, consts.PUBLISH)
resultCh := make(chan Result)
go func() {
defer close(resultCh)
resultCh <- p.waitForResponse(ctx, conn)
}()
finalResult := <-resultCh
return finalResult
}
type Queue struct {
consumers storage.IMap[string, *consumer]
tasks chan *QueuedTask // channel to hold tasks
name string
}
func newQueue(name string, queueSize int) *Queue {
return &Queue{
name: name,
consumers: memory.New[string, *consumer](),
tasks: make(chan *QueuedTask, queueSize), // buffer size for tasks
}
}
func (b *Broker) NewQueue(name string) *Queue {
q := &Queue{
name: name,
tasks: make(chan *QueuedTask, b.opts.queueSize),
consumers: memory.New[string, *consumer](),
}
b.queues.Set(name, q)
// Create DLQ for the queue
dlq := &Queue{
name: name + "_dlq",
tasks: make(chan *QueuedTask, b.opts.queueSize),
consumers: memory.New[string, *consumer](),
}
b.deadLetter.Set(name, dlq)
ctx := context.Background()
go b.dispatchWorker(ctx, q)
go b.dispatchWorker(ctx, dlq)
return q
}
type QueueTask struct {
ctx context.Context
payload *Task
priority int
retryCount int
index int
}
type PriorityQueue []*QueueTask
func (pq PriorityQueue) Len() int { return len(pq) }
func (pq PriorityQueue) Less(i, j int) bool {
return pq[i].priority > pq[j].priority
}
func (pq PriorityQueue) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i]
pq[i].index = i
pq[j].index = j
}
func (pq *PriorityQueue) Push(x interface{}) {
n := len(*pq)
task := x.(*QueueTask)
task.index = n
*pq = append(*pq, task)
}
func (pq *PriorityQueue) Pop() interface{} {
old := *pq
n := len(old)
task := old[n-1]
task.index = -1
*pq = old[0 : n-1]
return task
}
type Task struct {
CreatedAt time.Time `json:"created_at"`
ProcessedAt time.Time `json:"processed_at"`
Expiry time.Time `json:"expiry"`
Error error `json:"error"`
ID string `json:"id"`
Topic string `json:"topic"`
Status string `json:"status"`
Payload json.RawMessage `json:"payload"`
dag any
}
func (t *Task) GetFlow() any {
return t.dag
}
func NewTask(id string, payload json.RawMessage, nodeKey string, opts ...TaskOption) *Task {
if id == "" {
id = NewID()
}
task := &Task{ID: id, Payload: payload, Topic: nodeKey, CreatedAt: time.Now()}
for _, opt := range opts {
opt(task)
}
return task
}
// TaskOption defines a function type for setting options.
type TaskOption func(*Task)
func WithDAG(dag any) TaskOption {
return func(opts *Task) {
opts.dag = dag
}
}
func (b *Broker) TLSConfig() TLSConfig {
return b.opts.tlsConfig
}
func (b *Broker) SyncMode() bool {
return b.opts.syncMode
}
func (b *Broker) NotifyHandler() func(context.Context, Result) error {
return b.opts.notifyResponse
}
func (b *Broker) SetNotifyHandler(callback Callback) {
b.opts.notifyResponse = callback
}
func (b *Broker) HandleCallback(ctx context.Context, msg *codec.Message) {
if b.opts.callback != nil {
var result Result
err := json.Unmarshal(msg.Payload, &result)
if err == nil {
for _, callback := range b.opts.callback {
callback(ctx, result)
}
}
}
}