mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-07 00:43:35 +08:00
feat: sig
This commit is contained in:
322
consumer.go
Normal file
322
consumer.go
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
package mq
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oarkflow/json"
|
||||||
|
|
||||||
|
"github.com/oarkflow/mq/codec"
|
||||||
|
"github.com/oarkflow/mq/consts"
|
||||||
|
"github.com/oarkflow/mq/jsonparser"
|
||||||
|
"github.com/oarkflow/mq/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Processor interface {
|
||||||
|
ProcessTask(ctx context.Context, msg *Task) Result
|
||||||
|
Consume(ctx context.Context) error
|
||||||
|
Pause(ctx context.Context) error
|
||||||
|
Resume(ctx context.Context) error
|
||||||
|
Stop(ctx context.Context) error
|
||||||
|
Close() error
|
||||||
|
GetKey() string
|
||||||
|
SetKey(key string)
|
||||||
|
GetType() string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Consumer struct {
|
||||||
|
conn net.Conn
|
||||||
|
handler Handler
|
||||||
|
pool *Pool
|
||||||
|
opts *Options
|
||||||
|
id string
|
||||||
|
queue string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Consumer {
|
||||||
|
options := SetupOptions(opts...)
|
||||||
|
return &Consumer{
|
||||||
|
id: id,
|
||||||
|
opts: options,
|
||||||
|
queue: queue,
|
||||||
|
handler: handler,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
|
||||||
|
return codec.SendMessage(ctx, conn, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) receive(ctx context.Context, conn net.Conn) (*codec.Message, error) {
|
||||||
|
return codec.ReadMessage(ctx, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) Close() error {
|
||||||
|
c.pool.Stop()
|
||||||
|
err := c.conn.Close()
|
||||||
|
log.Printf("CONSUMER - Connection closed for consumer: %s", c.id)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) GetKey() string {
|
||||||
|
return c.id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) GetType() string {
|
||||||
|
return "consumer"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) SetKey(key string) {
|
||||||
|
c.id = key
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) Metrics() Metrics {
|
||||||
|
return c.pool.Metrics()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
|
||||||
|
headers := HeadersWithConsumerID(ctx, c.id)
|
||||||
|
msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
|
||||||
|
if err := c.send(ctx, c.conn, msg); err != nil {
|
||||||
|
return fmt.Errorf("error while trying to subscribe: %v", err)
|
||||||
|
}
|
||||||
|
return c.waitForAck(ctx, c.conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error {
|
||||||
|
fmt.Println("Consumer closed")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) {
|
||||||
|
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) error {
|
||||||
|
switch msg.Command {
|
||||||
|
case consts.PUBLISH:
|
||||||
|
c.ConsumeMessage(ctx, msg, conn)
|
||||||
|
case consts.CONSUMER_PAUSE:
|
||||||
|
err := c.Pause(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Unable to pause consumer: %v", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
case consts.CONSUMER_RESUME:
|
||||||
|
err := c.Resume(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Unable to resume consumer: %v", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
case consts.CONSUMER_STOP:
|
||||||
|
err := c.Stop(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Unable to stop consumer: %v", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
default:
|
||||||
|
log.Printf("CONSUMER - UNKNOWN_COMMAND ~> %s on %s", msg.Command, msg.Queue)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) sendMessageAck(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||||||
|
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue)
|
||||||
|
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||||
|
reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
|
||||||
|
if err := c.send(ctx, conn, reply); err != nil {
|
||||||
|
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||||||
|
c.sendMessageAck(ctx, msg, conn)
|
||||||
|
if msg.Payload == nil {
|
||||||
|
log.Printf("Received empty message payload")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var task Task
|
||||||
|
err := json.Unmarshal(msg.Payload, &task)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Error unmarshalling message: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue})
|
||||||
|
if err := c.pool.EnqueueTask(ctx, &task, 1); err != nil {
|
||||||
|
c.sendDenyMessage(ctx, task.ID, msg.Queue, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) ProcessTask(ctx context.Context, msg *Task) Result {
|
||||||
|
defer RecoverPanic(RecoverTitle)
|
||||||
|
queue, _ := GetQueue(ctx)
|
||||||
|
if msg.Topic == "" && queue != "" {
|
||||||
|
msg.Topic = queue
|
||||||
|
}
|
||||||
|
result := c.handler(ctx, msg)
|
||||||
|
result.Topic = msg.Topic
|
||||||
|
result.TaskID = msg.ID
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
|
||||||
|
if result.Status == "PENDING" && c.opts.respondPendingResult {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, result.Topic)
|
||||||
|
if result.Status == "" {
|
||||||
|
if result.Error != nil {
|
||||||
|
result.Status = "FAILED"
|
||||||
|
} else {
|
||||||
|
result.Status = "SUCCESS"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bt, _ := json.Marshal(result)
|
||||||
|
reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers)
|
||||||
|
if err := c.send(ctx, c.conn, reply); err != nil {
|
||||||
|
return fmt.Errorf("failed to send MESSAGE_RESPONSE: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) {
|
||||||
|
headers := HeadersWithConsumerID(ctx, c.id)
|
||||||
|
reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
|
||||||
|
if sendErr := c.send(ctx, c.conn, reply); sendErr != nil {
|
||||||
|
log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) attemptConnect() error {
|
||||||
|
var err error
|
||||||
|
delay := c.opts.initialDelay
|
||||||
|
for i := 0; i < c.opts.maxRetries; i++ {
|
||||||
|
conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig)
|
||||||
|
if err == nil {
|
||||||
|
c.conn = conn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent)
|
||||||
|
log.Printf("CONSUMER - SUBSCRIBE ~> Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
|
||||||
|
time.Sleep(sleepDuration)
|
||||||
|
delay *= 2
|
||||||
|
if delay > c.opts.maxBackoff {
|
||||||
|
delay = c.opts.maxBackoff
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
|
||||||
|
msg, err := c.receive(ctx, conn)
|
||||||
|
if err == nil {
|
||||||
|
ctx = SetHeaders(ctx, msg.Headers)
|
||||||
|
return c.OnMessage(ctx, msg, conn)
|
||||||
|
}
|
||||||
|
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
|
||||||
|
err1 := c.OnClose(ctx, conn)
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.OnError(ctx, conn, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) Consume(ctx context.Context) error {
|
||||||
|
err := c.attemptConnect()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.pool = NewPool(
|
||||||
|
c.opts.numOfWorkers,
|
||||||
|
WithTaskQueueSize(c.opts.queueSize),
|
||||||
|
WithMaxMemoryLoad(c.opts.maxMemoryLoad),
|
||||||
|
WithHandler(c.ProcessTask),
|
||||||
|
WithPoolCallback(c.OnResponse),
|
||||||
|
WithTaskStorage(c.opts.storage),
|
||||||
|
)
|
||||||
|
if err := c.subscribe(ctx, c.queue); err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
|
||||||
|
}
|
||||||
|
c.pool.Start(c.opts.numOfWorkers)
|
||||||
|
// Infinite loop to continuously read messages and reconnect if needed.
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Println("Context canceled, stopping consumer.")
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
if c.opts.ConsumerRateLimiter != nil {
|
||||||
|
c.opts.ConsumerRateLimiter.Wait()
|
||||||
|
}
|
||||||
|
if err := c.readMessage(ctx, c.conn); err != nil {
|
||||||
|
log.Printf("Error reading message: %v, attempting reconnection...", err)
|
||||||
|
for {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if rErr := c.attemptConnect(); rErr != nil {
|
||||||
|
log.Printf("Reconnection attempt failed: %v", rErr)
|
||||||
|
time.Sleep(c.opts.initialDelay)
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := c.subscribe(ctx, c.queue); err != nil {
|
||||||
|
log.Printf("Failed to re-subscribe on reconnection: %v", err)
|
||||||
|
time.Sleep(c.opts.initialDelay)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) waitForAck(ctx context.Context, conn net.Conn) error {
|
||||||
|
msg, err := c.receive(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if msg.Command == consts.SUBSCRIBE_ACK {
|
||||||
|
log.Printf("CONSUMER - SUBSCRIBE_ACK ~> %s on %s", c.id, msg.Queue)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) Pause(ctx context.Context) error {
|
||||||
|
return c.operate(ctx, consts.CONSUMER_PAUSED, c.pool.Pause)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) Resume(ctx context.Context) error {
|
||||||
|
return c.operate(ctx, consts.CONSUMER_RESUMED, c.pool.Resume)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) Stop(ctx context.Context) error {
|
||||||
|
return c.operate(ctx, consts.CONSUMER_STOPPED, c.pool.Stop)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) operate(ctx context.Context, cmd consts.CMD, poolOperation func()) error {
|
||||||
|
if err := c.sendOpsMessage(ctx, cmd); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
poolOperation()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
|
||||||
|
headers := HeadersWithConsumerID(ctx, c.id)
|
||||||
|
msg := codec.NewMessage(cmd, nil, c.queue, headers)
|
||||||
|
return c.send(ctx, c.conn, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) Conn() net.Conn {
|
||||||
|
return c.conn
|
||||||
|
}
|
567
mq.go
567
mq.go
@@ -7,7 +7,6 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oarkflow/errors"
|
"github.com/oarkflow/errors"
|
||||||
@@ -327,6 +326,12 @@ func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Con
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *Broker) AdjustConsumerWorkers(noOfWorkers int, consumerID ...string) {
|
||||||
|
b.consumers.ForEach(func(_ string, c *consumer) bool {
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) {
|
func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) {
|
||||||
consumerID, _ := GetConsumerID(ctx)
|
consumerID, _ := GetConsumerID(ctx)
|
||||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||||
@@ -577,11 +582,11 @@ func (b *Broker) RemoveConsumer(consumerID string, queues ...string) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) handleConsumer(ctx context.Context, cmd consts.CMD, state consts.ConsumerState, consumerID string, queues ...string) {
|
func (b *Broker) handleConsumer(ctx context.Context, cmd consts.CMD, state consts.ConsumerState, consumerID string, payload []byte, queues ...string) {
|
||||||
fn := func(queue *Queue) {
|
fn := func(queue *Queue) {
|
||||||
con, ok := queue.consumers.Get(consumerID)
|
con, ok := queue.consumers.Get(consumerID)
|
||||||
if ok {
|
if ok {
|
||||||
ack := codec.NewMessage(cmd, utils.ToByte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID})
|
ack := codec.NewMessage(cmd, payload, queue.name, map[string]string{consts.ConsumerKey: consumerID})
|
||||||
err := b.send(ctx, con.conn, ack)
|
err := b.send(ctx, con.conn, ack)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
con.state = state
|
con.state = state
|
||||||
@@ -603,15 +608,15 @@ func (b *Broker) handleConsumer(ctx context.Context, cmd consts.CMD, state const
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) PauseConsumer(ctx context.Context, consumerID string, queues ...string) {
|
func (b *Broker) PauseConsumer(ctx context.Context, consumerID string, queues ...string) {
|
||||||
b.handleConsumer(ctx, consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, queues...)
|
b.handleConsumer(ctx, consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, utils.ToByte("{}"), queues...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) ResumeConsumer(ctx context.Context, consumerID string, queues ...string) {
|
func (b *Broker) ResumeConsumer(ctx context.Context, consumerID string, queues ...string) {
|
||||||
b.handleConsumer(ctx, consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, queues...)
|
b.handleConsumer(ctx, consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, utils.ToByte("{}"), queues...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) StopConsumer(ctx context.Context, consumerID string, queues ...string) {
|
func (b *Broker) StopConsumer(ctx context.Context, consumerID string, queues ...string) {
|
||||||
b.handleConsumer(ctx, consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, queues...)
|
b.handleConsumer(ctx, consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, utils.ToByte("{}"), queues...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
|
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
|
||||||
@@ -721,486 +726,6 @@ func (b *Broker) SetURL(url string) {
|
|||||||
b.opts.brokerAddr = url
|
b.opts.brokerAddr = url
|
||||||
}
|
}
|
||||||
|
|
||||||
type Processor interface {
|
|
||||||
ProcessTask(ctx context.Context, msg *Task) Result
|
|
||||||
Consume(ctx context.Context) error
|
|
||||||
Pause(ctx context.Context) error
|
|
||||||
Resume(ctx context.Context) error
|
|
||||||
Stop(ctx context.Context) error
|
|
||||||
Close() error
|
|
||||||
GetKey() string
|
|
||||||
SetKey(key string)
|
|
||||||
GetType() string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Consumer struct {
|
|
||||||
conn net.Conn
|
|
||||||
handler Handler
|
|
||||||
pool *Pool
|
|
||||||
opts *Options
|
|
||||||
id string
|
|
||||||
queue string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Consumer {
|
|
||||||
options := SetupOptions(opts...)
|
|
||||||
return &Consumer{
|
|
||||||
id: id,
|
|
||||||
opts: options,
|
|
||||||
queue: queue,
|
|
||||||
handler: handler,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
|
|
||||||
return codec.SendMessage(ctx, conn, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) receive(ctx context.Context, conn net.Conn) (*codec.Message, error) {
|
|
||||||
return codec.ReadMessage(ctx, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) Close() error {
|
|
||||||
c.pool.Stop()
|
|
||||||
err := c.conn.Close()
|
|
||||||
log.Printf("CONSUMER - Connection closed for consumer: %s", c.id)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) GetKey() string {
|
|
||||||
return c.id
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) GetType() string {
|
|
||||||
return "consumer"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) SetKey(key string) {
|
|
||||||
c.id = key
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) Metrics() Metrics {
|
|
||||||
return c.pool.Metrics()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
|
|
||||||
headers := HeadersWithConsumerID(ctx, c.id)
|
|
||||||
msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
|
|
||||||
if err := c.send(ctx, c.conn, msg); err != nil {
|
|
||||||
return fmt.Errorf("error while trying to subscribe: %v", err)
|
|
||||||
}
|
|
||||||
return c.waitForAck(ctx, c.conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error {
|
|
||||||
fmt.Println("Consumer closed")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) {
|
|
||||||
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) error {
|
|
||||||
switch msg.Command {
|
|
||||||
case consts.PUBLISH:
|
|
||||||
c.ConsumeMessage(ctx, msg, conn)
|
|
||||||
case consts.CONSUMER_PAUSE:
|
|
||||||
err := c.Pause(ctx)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Unable to pause consumer: %v", err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
case consts.CONSUMER_RESUME:
|
|
||||||
err := c.Resume(ctx)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Unable to resume consumer: %v", err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
case consts.CONSUMER_STOP:
|
|
||||||
err := c.Stop(ctx)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Unable to stop consumer: %v", err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
default:
|
|
||||||
log.Printf("CONSUMER - UNKNOWN_COMMAND ~> %s on %s", msg.Command, msg.Queue)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) sendMessageAck(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
|
||||||
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue)
|
|
||||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
|
||||||
reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
|
|
||||||
if err := c.send(ctx, conn, reply); err != nil {
|
|
||||||
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
|
||||||
c.sendMessageAck(ctx, msg, conn)
|
|
||||||
if msg.Payload == nil {
|
|
||||||
log.Printf("Received empty message payload")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var task Task
|
|
||||||
err := json.Unmarshal(msg.Payload, &task)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Error unmarshalling message: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue})
|
|
||||||
if err := c.pool.EnqueueTask(ctx, &task, 1); err != nil {
|
|
||||||
c.sendDenyMessage(ctx, task.ID, msg.Queue, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) ProcessTask(ctx context.Context, msg *Task) Result {
|
|
||||||
defer RecoverPanic(RecoverTitle)
|
|
||||||
queue, _ := GetQueue(ctx)
|
|
||||||
if msg.Topic == "" && queue != "" {
|
|
||||||
msg.Topic = queue
|
|
||||||
}
|
|
||||||
result := c.handler(ctx, msg)
|
|
||||||
result.Topic = msg.Topic
|
|
||||||
result.TaskID = msg.ID
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
|
|
||||||
if result.Status == "PENDING" && c.opts.respondPendingResult {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, result.Topic)
|
|
||||||
if result.Status == "" {
|
|
||||||
if result.Error != nil {
|
|
||||||
result.Status = "FAILED"
|
|
||||||
} else {
|
|
||||||
result.Status = "SUCCESS"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
bt, _ := json.Marshal(result)
|
|
||||||
reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers)
|
|
||||||
if err := c.send(ctx, c.conn, reply); err != nil {
|
|
||||||
return fmt.Errorf("failed to send MESSAGE_RESPONSE: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) {
|
|
||||||
headers := HeadersWithConsumerID(ctx, c.id)
|
|
||||||
reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
|
|
||||||
if sendErr := c.send(ctx, c.conn, reply); sendErr != nil {
|
|
||||||
log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) attemptConnect() error {
|
|
||||||
var err error
|
|
||||||
delay := c.opts.initialDelay
|
|
||||||
for i := 0; i < c.opts.maxRetries; i++ {
|
|
||||||
conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig)
|
|
||||||
if err == nil {
|
|
||||||
c.conn = conn
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent)
|
|
||||||
log.Printf("CONSUMER - SUBSCRIBE ~> Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
|
|
||||||
time.Sleep(sleepDuration)
|
|
||||||
delay *= 2
|
|
||||||
if delay > c.opts.maxBackoff {
|
|
||||||
delay = c.opts.maxBackoff
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
|
|
||||||
msg, err := c.receive(ctx, conn)
|
|
||||||
if err == nil {
|
|
||||||
ctx = SetHeaders(ctx, msg.Headers)
|
|
||||||
return c.OnMessage(ctx, msg, conn)
|
|
||||||
}
|
|
||||||
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
|
|
||||||
err1 := c.OnClose(ctx, conn)
|
|
||||||
if err1 != nil {
|
|
||||||
return err1
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.OnError(ctx, conn, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) Consume(ctx context.Context) error {
|
|
||||||
err := c.attemptConnect()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.pool = NewPool(
|
|
||||||
c.opts.numOfWorkers,
|
|
||||||
WithTaskQueueSize(c.opts.queueSize),
|
|
||||||
WithMaxMemoryLoad(c.opts.maxMemoryLoad),
|
|
||||||
WithHandler(c.ProcessTask),
|
|
||||||
WithPoolCallback(c.OnResponse),
|
|
||||||
WithTaskStorage(c.opts.storage),
|
|
||||||
)
|
|
||||||
if err := c.subscribe(ctx, c.queue); err != nil {
|
|
||||||
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
|
|
||||||
}
|
|
||||||
c.pool.Start(c.opts.numOfWorkers)
|
|
||||||
// Infinite loop to continuously read messages and reconnect if needed.
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
log.Println("Context canceled, stopping consumer.")
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
if c.opts.ConsumerRateLimiter != nil {
|
|
||||||
c.opts.ConsumerRateLimiter.Wait()
|
|
||||||
}
|
|
||||||
if err := c.readMessage(ctx, c.conn); err != nil {
|
|
||||||
log.Printf("Error reading message: %v, attempting reconnection...", err)
|
|
||||||
// Attempt reconnection loop.
|
|
||||||
for {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if rErr := c.attemptConnect(); rErr != nil {
|
|
||||||
log.Printf("Reconnection attempt failed: %v", rErr)
|
|
||||||
time.Sleep(c.opts.initialDelay)
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := c.subscribe(ctx, c.queue); err != nil {
|
|
||||||
log.Printf("Failed to re-subscribe on reconnection: %v", err)
|
|
||||||
time.Sleep(c.opts.initialDelay)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) waitForAck(ctx context.Context, conn net.Conn) error {
|
|
||||||
msg, err := c.receive(ctx, conn)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if msg.Command == consts.SUBSCRIBE_ACK {
|
|
||||||
log.Printf("CONSUMER - SUBSCRIBE_ACK ~> %s on %s", c.id, msg.Queue)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) Pause(ctx context.Context) error {
|
|
||||||
return c.operate(ctx, consts.CONSUMER_PAUSED, c.pool.Pause)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) Resume(ctx context.Context) error {
|
|
||||||
return c.operate(ctx, consts.CONSUMER_RESUMED, c.pool.Resume)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) Stop(ctx context.Context) error {
|
|
||||||
return c.operate(ctx, consts.CONSUMER_STOPPED, c.pool.Stop)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) operate(ctx context.Context, cmd consts.CMD, poolOperation func()) error {
|
|
||||||
if err := c.sendOpsMessage(ctx, cmd); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
poolOperation()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
|
|
||||||
headers := HeadersWithConsumerID(ctx, c.id)
|
|
||||||
msg := codec.NewMessage(cmd, nil, c.queue, headers)
|
|
||||||
return c.send(ctx, c.conn, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) Conn() net.Conn {
|
|
||||||
return c.conn
|
|
||||||
}
|
|
||||||
|
|
||||||
type Publisher struct {
|
|
||||||
opts *Options
|
|
||||||
id string
|
|
||||||
conn net.Conn
|
|
||||||
connLock sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPublisher(id string, opts ...Option) *Publisher {
|
|
||||||
options := SetupOptions(opts...)
|
|
||||||
return &Publisher{
|
|
||||||
id: id,
|
|
||||||
opts: options,
|
|
||||||
conn: nil,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// New method to ensure a persistent connection.
|
|
||||||
func (p *Publisher) ensureConnection(ctx context.Context) error {
|
|
||||||
p.connLock.Lock()
|
|
||||||
defer p.connLock.Unlock()
|
|
||||||
if p.conn != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
delay := p.opts.initialDelay
|
|
||||||
for i := 0; i < p.opts.maxRetries; i++ {
|
|
||||||
var conn net.Conn
|
|
||||||
conn, err = GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
|
||||||
if err == nil {
|
|
||||||
p.conn = conn
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
sleepDuration := utils.CalculateJitter(delay, p.opts.jitterPercent)
|
|
||||||
log.Printf("PUBLISHER - ensureConnection failed: %v, attempt %d/%d, retrying in %v...", err, i+1, p.opts.maxRetries, sleepDuration)
|
|
||||||
time.Sleep(sleepDuration)
|
|
||||||
delay *= 2
|
|
||||||
if delay > p.opts.maxBackoff {
|
|
||||||
delay = p.opts.maxBackoff
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to connect to broker after retries: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Modified Publish method that uses the persistent connection.
|
|
||||||
func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error {
|
|
||||||
// Ensure connection is established.
|
|
||||||
if err := p.ensureConnection(ctx); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
delay := p.opts.initialDelay
|
|
||||||
for i := 0; i < p.opts.maxRetries; i++ {
|
|
||||||
// Use the persistent connection.
|
|
||||||
p.connLock.Lock()
|
|
||||||
conn := p.conn
|
|
||||||
p.connLock.Unlock()
|
|
||||||
err := p.send(ctx, queue, task, conn, consts.PUBLISH)
|
|
||||||
if err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
log.Printf("PUBLISHER - Failed publishing: %v, attempt %d/%d, retrying...", err, i+1, p.opts.maxRetries)
|
|
||||||
// On error, close and reset the connection.
|
|
||||||
p.connLock.Lock()
|
|
||||||
if p.conn != nil {
|
|
||||||
p.conn.Close()
|
|
||||||
p.conn = nil
|
|
||||||
}
|
|
||||||
p.connLock.Unlock()
|
|
||||||
sleepDuration := utils.CalculateJitter(delay, p.opts.jitterPercent)
|
|
||||||
time.Sleep(sleepDuration)
|
|
||||||
delay *= 2
|
|
||||||
if delay > p.opts.maxBackoff {
|
|
||||||
delay = p.opts.maxBackoff
|
|
||||||
}
|
|
||||||
// Ensure connection is re-established.
|
|
||||||
if err := p.ensureConnection(ctx); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fmt.Errorf("failed to publish after retries")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error {
|
|
||||||
headers := WithHeaders(ctx, map[string]string{
|
|
||||||
consts.PublisherKey: p.id,
|
|
||||||
consts.ContentType: consts.TypeJson,
|
|
||||||
})
|
|
||||||
if task.ID == "" {
|
|
||||||
task.ID = NewID()
|
|
||||||
}
|
|
||||||
task.CreatedAt = time.Now()
|
|
||||||
payload, err := json.Marshal(task)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
msg := codec.NewMessage(command, payload, queue, headers)
|
|
||||||
if err := codec.SendMessage(ctx, conn, msg); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.waitForAck(ctx, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Publisher) waitForAck(ctx context.Context, conn net.Conn) error {
|
|
||||||
msg, err := codec.ReadMessage(ctx, conn)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if msg.Command == consts.PUBLISH_ACK {
|
|
||||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
|
||||||
log.Printf("PUBLISHER - PUBLISH_ACK ~> from %s on %s for Task %s", p.id, msg.Queue, taskID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Publisher) waitForResponse(ctx context.Context, conn net.Conn) Result {
|
|
||||||
msg, err := codec.ReadMessage(ctx, conn)
|
|
||||||
if err != nil {
|
|
||||||
return Result{Error: err}
|
|
||||||
}
|
|
||||||
if msg.Command == consts.RESPONSE {
|
|
||||||
var result Result
|
|
||||||
err = json.Unmarshal(msg.Payload, &result)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
err = fmt.Errorf("expected RESPONSE, got: %v", msg.Command)
|
|
||||||
return Result{Error: err}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Publisher) onClose(_ context.Context, conn net.Conn) error {
|
|
||||||
fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr())
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Publisher) onError(_ context.Context, conn net.Conn, err error) {
|
|
||||||
fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Publisher) Request(ctx context.Context, task Task, queue string) Result {
|
|
||||||
ctx = SetHeaders(ctx, map[string]string{
|
|
||||||
consts.AwaitResponseKey: "true",
|
|
||||||
})
|
|
||||||
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("failed to connect to broker: %w", err)
|
|
||||||
return Result{Error: err}
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
_ = conn.Close()
|
|
||||||
}()
|
|
||||||
err = p.send(ctx, queue, task, conn, consts.PUBLISH)
|
|
||||||
resultCh := make(chan Result)
|
|
||||||
go func() {
|
|
||||||
defer close(resultCh)
|
|
||||||
resultCh <- p.waitForResponse(ctx, conn)
|
|
||||||
}()
|
|
||||||
finalResult := <-resultCh
|
|
||||||
return finalResult
|
|
||||||
}
|
|
||||||
|
|
||||||
type Queue struct {
|
|
||||||
consumers storage.IMap[string, *consumer]
|
|
||||||
tasks chan *QueuedTask // channel to hold tasks
|
|
||||||
name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newQueue(name string, queueSize int) *Queue {
|
|
||||||
return &Queue{
|
|
||||||
name: name,
|
|
||||||
consumers: memory.New[string, *consumer](),
|
|
||||||
tasks: make(chan *QueuedTask, queueSize), // buffer size for tasks
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) NewQueue(name string) *Queue {
|
func (b *Broker) NewQueue(name string) *Queue {
|
||||||
q := &Queue{
|
q := &Queue{
|
||||||
name: name,
|
name: name,
|
||||||
@@ -1222,76 +747,6 @@ func (b *Broker) NewQueue(name string) *Queue {
|
|||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueueTask struct {
|
|
||||||
ctx context.Context
|
|
||||||
payload *Task
|
|
||||||
priority int
|
|
||||||
retryCount int
|
|
||||||
index int
|
|
||||||
}
|
|
||||||
|
|
||||||
type PriorityQueue []*QueueTask
|
|
||||||
|
|
||||||
func (pq PriorityQueue) Len() int { return len(pq) }
|
|
||||||
func (pq PriorityQueue) Less(i, j int) bool {
|
|
||||||
return pq[i].priority > pq[j].priority
|
|
||||||
}
|
|
||||||
func (pq PriorityQueue) Swap(i, j int) {
|
|
||||||
pq[i], pq[j] = pq[j], pq[i]
|
|
||||||
pq[i].index = i
|
|
||||||
pq[j].index = j
|
|
||||||
}
|
|
||||||
func (pq *PriorityQueue) Push(x interface{}) {
|
|
||||||
n := len(*pq)
|
|
||||||
task := x.(*QueueTask)
|
|
||||||
task.index = n
|
|
||||||
*pq = append(*pq, task)
|
|
||||||
}
|
|
||||||
func (pq *PriorityQueue) Pop() interface{} {
|
|
||||||
old := *pq
|
|
||||||
n := len(old)
|
|
||||||
task := old[n-1]
|
|
||||||
task.index = -1
|
|
||||||
*pq = old[0 : n-1]
|
|
||||||
return task
|
|
||||||
}
|
|
||||||
|
|
||||||
type Task struct {
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
ProcessedAt time.Time `json:"processed_at"`
|
|
||||||
Expiry time.Time `json:"expiry"`
|
|
||||||
Error error `json:"error"`
|
|
||||||
ID string `json:"id"`
|
|
||||||
Topic string `json:"topic"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Payload json.RawMessage `json:"payload"`
|
|
||||||
dag any
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Task) GetFlow() any {
|
|
||||||
return t.dag
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTask(id string, payload json.RawMessage, nodeKey string, opts ...TaskOption) *Task {
|
|
||||||
if id == "" {
|
|
||||||
id = NewID()
|
|
||||||
}
|
|
||||||
task := &Task{ID: id, Payload: payload, Topic: nodeKey, CreatedAt: time.Now()}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(task)
|
|
||||||
}
|
|
||||||
return task
|
|
||||||
}
|
|
||||||
|
|
||||||
// TaskOption defines a function type for setting options.
|
|
||||||
type TaskOption func(*Task)
|
|
||||||
|
|
||||||
func WithDAG(dag any) TaskOption {
|
|
||||||
return func(opts *Task) {
|
|
||||||
opts.dag = dag
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) TLSConfig() TLSConfig {
|
func (b *Broker) TLSConfig() TLSConfig {
|
||||||
return b.opts.tlsConfig
|
return b.opts.tlsConfig
|
||||||
}
|
}
|
||||||
|
@@ -280,3 +280,12 @@ func DisableConsumerRateLimit() Option {
|
|||||||
opts.ConsumerRateLimiter = nil
|
opts.ConsumerRateLimiter = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TaskOption defines a function type for setting options.
|
||||||
|
type TaskOption func(*Task)
|
||||||
|
|
||||||
|
func WithDAG(dag any) TaskOption {
|
||||||
|
return func(opts *Task) {
|
||||||
|
opts.dag = dag
|
||||||
|
}
|
||||||
|
}
|
||||||
|
177
publisher.go
Normal file
177
publisher.go
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
package mq
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oarkflow/json"
|
||||||
|
|
||||||
|
"github.com/oarkflow/mq/codec"
|
||||||
|
"github.com/oarkflow/mq/consts"
|
||||||
|
"github.com/oarkflow/mq/jsonparser"
|
||||||
|
"github.com/oarkflow/mq/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Publisher struct {
|
||||||
|
opts *Options
|
||||||
|
id string
|
||||||
|
conn net.Conn
|
||||||
|
connLock sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPublisher(id string, opts ...Option) *Publisher {
|
||||||
|
options := SetupOptions(opts...)
|
||||||
|
return &Publisher{
|
||||||
|
id: id,
|
||||||
|
opts: options,
|
||||||
|
conn: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// New method to ensure a persistent connection.
|
||||||
|
func (p *Publisher) ensureConnection(ctx context.Context) error {
|
||||||
|
p.connLock.Lock()
|
||||||
|
defer p.connLock.Unlock()
|
||||||
|
if p.conn != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
delay := p.opts.initialDelay
|
||||||
|
for i := 0; i < p.opts.maxRetries; i++ {
|
||||||
|
var conn net.Conn
|
||||||
|
conn, err = GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
||||||
|
if err == nil {
|
||||||
|
p.conn = conn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sleepDuration := utils.CalculateJitter(delay, p.opts.jitterPercent)
|
||||||
|
log.Printf("PUBLISHER - ensureConnection failed: %v, attempt %d/%d, retrying in %v...", err, i+1, p.opts.maxRetries, sleepDuration)
|
||||||
|
time.Sleep(sleepDuration)
|
||||||
|
delay *= 2
|
||||||
|
if delay > p.opts.maxBackoff {
|
||||||
|
delay = p.opts.maxBackoff
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to connect to broker after retries: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish method that uses the persistent connection.
|
||||||
|
func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error {
|
||||||
|
// Ensure connection is established.
|
||||||
|
if err := p.ensureConnection(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
delay := p.opts.initialDelay
|
||||||
|
for i := 0; i < p.opts.maxRetries; i++ {
|
||||||
|
// Use the persistent connection.
|
||||||
|
p.connLock.Lock()
|
||||||
|
conn := p.conn
|
||||||
|
p.connLock.Unlock()
|
||||||
|
err := p.send(ctx, queue, task, conn, consts.PUBLISH)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.Printf("PUBLISHER - Failed publishing: %v, attempt %d/%d, retrying...", err, i+1, p.opts.maxRetries)
|
||||||
|
// On error, close and reset the connection.
|
||||||
|
p.connLock.Lock()
|
||||||
|
if p.conn != nil {
|
||||||
|
p.conn.Close()
|
||||||
|
p.conn = nil
|
||||||
|
}
|
||||||
|
p.connLock.Unlock()
|
||||||
|
sleepDuration := utils.CalculateJitter(delay, p.opts.jitterPercent)
|
||||||
|
time.Sleep(sleepDuration)
|
||||||
|
delay *= 2
|
||||||
|
if delay > p.opts.maxBackoff {
|
||||||
|
delay = p.opts.maxBackoff
|
||||||
|
}
|
||||||
|
// Ensure connection is re-established.
|
||||||
|
if err := p.ensureConnection(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to publish after retries")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error {
|
||||||
|
headers := WithHeaders(ctx, map[string]string{
|
||||||
|
consts.PublisherKey: p.id,
|
||||||
|
consts.ContentType: consts.TypeJson,
|
||||||
|
})
|
||||||
|
if task.ID == "" {
|
||||||
|
task.ID = NewID()
|
||||||
|
}
|
||||||
|
task.CreatedAt = time.Now()
|
||||||
|
payload, err := json.Marshal(task)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
msg := codec.NewMessage(command, payload, queue, headers)
|
||||||
|
if err := codec.SendMessage(ctx, conn, msg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.waitForAck(ctx, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) waitForAck(ctx context.Context, conn net.Conn) error {
|
||||||
|
msg, err := codec.ReadMessage(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if msg.Command == consts.PUBLISH_ACK {
|
||||||
|
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||||
|
log.Printf("PUBLISHER - PUBLISH_ACK ~> from %s on %s for Task %s", p.id, msg.Queue, taskID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) waitForResponse(ctx context.Context, conn net.Conn) Result {
|
||||||
|
msg, err := codec.ReadMessage(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
return Result{Error: err}
|
||||||
|
}
|
||||||
|
if msg.Command == consts.RESPONSE {
|
||||||
|
var result Result
|
||||||
|
err = json.Unmarshal(msg.Payload, &result)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("expected RESPONSE, got: %v", msg.Command)
|
||||||
|
return Result{Error: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) onClose(_ context.Context, conn net.Conn) error {
|
||||||
|
fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) onError(_ context.Context, conn net.Conn, err error) {
|
||||||
|
fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) Request(ctx context.Context, task Task, queue string) Result {
|
||||||
|
ctx = SetHeaders(ctx, map[string]string{
|
||||||
|
consts.AwaitResponseKey: "true",
|
||||||
|
})
|
||||||
|
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to connect to broker: %w", err)
|
||||||
|
return Result{Error: err}
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = conn.Close()
|
||||||
|
}()
|
||||||
|
err = p.send(ctx, queue, task, conn, consts.PUBLISH)
|
||||||
|
resultCh := make(chan Result)
|
||||||
|
go func() {
|
||||||
|
defer close(resultCh)
|
||||||
|
resultCh <- p.waitForResponse(ctx, conn)
|
||||||
|
}()
|
||||||
|
finalResult := <-resultCh
|
||||||
|
return finalResult
|
||||||
|
}
|
86
task.go
Normal file
86
task.go
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
package mq
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oarkflow/json"
|
||||||
|
|
||||||
|
"github.com/oarkflow/mq/storage"
|
||||||
|
"github.com/oarkflow/mq/storage/memory"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Queue struct {
|
||||||
|
consumers storage.IMap[string, *consumer]
|
||||||
|
tasks chan *QueuedTask // channel to hold tasks
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newQueue(name string, queueSize int) *Queue {
|
||||||
|
return &Queue{
|
||||||
|
name: name,
|
||||||
|
consumers: memory.New[string, *consumer](),
|
||||||
|
tasks: make(chan *QueuedTask, queueSize), // buffer size for tasks
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueueTask struct {
|
||||||
|
ctx context.Context
|
||||||
|
payload *Task
|
||||||
|
priority int
|
||||||
|
retryCount int
|
||||||
|
index int
|
||||||
|
}
|
||||||
|
|
||||||
|
type PriorityQueue []*QueueTask
|
||||||
|
|
||||||
|
func (pq PriorityQueue) Len() int { return len(pq) }
|
||||||
|
func (pq PriorityQueue) Less(i, j int) bool {
|
||||||
|
return pq[i].priority > pq[j].priority
|
||||||
|
}
|
||||||
|
func (pq PriorityQueue) Swap(i, j int) {
|
||||||
|
pq[i], pq[j] = pq[j], pq[i]
|
||||||
|
pq[i].index = i
|
||||||
|
pq[j].index = j
|
||||||
|
}
|
||||||
|
func (pq *PriorityQueue) Push(x interface{}) {
|
||||||
|
n := len(*pq)
|
||||||
|
task := x.(*QueueTask)
|
||||||
|
task.index = n
|
||||||
|
*pq = append(*pq, task)
|
||||||
|
}
|
||||||
|
func (pq *PriorityQueue) Pop() interface{} {
|
||||||
|
old := *pq
|
||||||
|
n := len(old)
|
||||||
|
task := old[n-1]
|
||||||
|
task.index = -1
|
||||||
|
*pq = old[0 : n-1]
|
||||||
|
return task
|
||||||
|
}
|
||||||
|
|
||||||
|
type Task struct {
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
ProcessedAt time.Time `json:"processed_at"`
|
||||||
|
Expiry time.Time `json:"expiry"`
|
||||||
|
Error error `json:"error"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
Topic string `json:"topic"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Payload json.RawMessage `json:"payload"`
|
||||||
|
dag any
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Task) GetFlow() any {
|
||||||
|
return t.dag
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTask(id string, payload json.RawMessage, nodeKey string, opts ...TaskOption) *Task {
|
||||||
|
if id == "" {
|
||||||
|
id = NewID()
|
||||||
|
}
|
||||||
|
task := &Task{ID: id, Payload: payload, Topic: nodeKey, CreatedAt: time.Now()}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(task)
|
||||||
|
}
|
||||||
|
return task
|
||||||
|
}
|
Reference in New Issue
Block a user