feat: separate broker

This commit is contained in:
Oarkflow
2024-10-05 18:34:24 +05:45
parent 324c6f691e
commit 40f5d0dad3
18 changed files with 508 additions and 1593 deletions

410
broker.go
View File

@@ -4,37 +4,35 @@ import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"strings"
"time"
"github.com/oarkflow/xsync"
"github.com/oarkflow/mq/codec"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/jsonparser"
"github.com/oarkflow/mq/utils"
)
type QueuedTask struct {
Message *codec.Message
RetryCount int
}
type consumer struct {
id string
conn net.Conn
}
func (p *consumer) send(ctx context.Context, cmd any) error {
return Write(ctx, p.conn, cmd)
}
type publisher struct {
id string
conn net.Conn
}
func (p *publisher) send(ctx context.Context, cmd any) error {
return Write(ctx, p.conn, cmd)
}
type Handler func(context.Context, Task) Result
type Broker struct {
queues xsync.IMap[string, *Queue]
consumers xsync.IMap[string, *consumer]
@@ -42,100 +40,17 @@ type Broker struct {
opts Options
}
type Queue struct {
name string
consumers xsync.IMap[string, *consumer]
messages xsync.IMap[string, *Task]
deferred xsync.IMap[string, *Task]
}
func newQueue(name string) *Queue {
return &Queue{
name: name,
consumers: xsync.NewMap[string, *consumer](),
messages: xsync.NewMap[string, *Task](),
deferred: xsync.NewMap[string, *Task](),
}
}
func (queue *Queue) send(ctx context.Context, cmd any) {
queue.consumers.ForEach(func(_ string, client *consumer) bool {
err := client.send(ctx, cmd)
if err != nil {
return false
}
return true
})
}
type Task struct {
ID string `json:"id"`
Payload json.RawMessage `json:"payload"`
CreatedAt time.Time `json:"created_at"`
ProcessedAt time.Time `json:"processed_at"`
CurrentQueue string `json:"current_queue"`
Status string `json:"status"`
Error error `json:"error"`
}
type Command struct {
ID string `json:"id"`
Command consts.CMD `json:"command"`
Queue string `json:"queue"`
MessageID string `json:"message_id"`
Payload json.RawMessage `json:"payload,omitempty"` // Used for carrying the task payload
Error string `json:"error,omitempty"`
}
type Result struct {
Command string `json:"command"`
Payload json.RawMessage `json:"payload"`
Queue string `json:"queue"`
MessageID string `json:"message_id"`
Error error `json:"error"`
Status string `json:"status"`
}
func NewBroker(opts ...Option) *Broker {
options := defaultOptions()
for _, opt := range opts {
opt(&options)
}
b := &Broker{
options := setupOptions(opts...)
return &Broker{
queues: xsync.NewMap[string, *Queue](),
publishers: xsync.NewMap[string, *publisher](),
consumers: xsync.NewMap[string, *consumer](),
opts: options,
}
b.opts = defaultHandlers(options, b.onMessage, b.onClose, b.onError)
return b
}
func (b *Broker) Send(ctx context.Context, cmd Command) error {
queue, ok := b.queues.Get(cmd.Queue)
if !ok || queue == nil {
return errors.New("invalid queue or not exists")
}
queue.send(ctx, cmd)
return nil
}
func (b *Broker) TLSConfig() TLSConfig {
return b.opts.tlsConfig
}
func (b *Broker) SyncMode() bool {
return b.opts.syncMode
}
func (b *Broker) sendToPublisher(ctx context.Context, publisherID string, result Result) error {
pub, ok := b.publishers.Get(publisherID)
if !ok {
return nil
}
return pub.send(ctx, result)
}
func (b *Broker) onClose(ctx context.Context, _ net.Conn) error {
func (b *Broker) OnClose(ctx context.Context, _ net.Conn) error {
consumerID, ok := GetConsumerID(ctx)
if ok && consumerID != "" {
if con, exists := b.consumers.Get(consumerID); exists {
@@ -157,11 +72,94 @@ func (b *Broker) onClose(ctx context.Context, _ net.Conn) error {
return nil
}
func (b *Broker) onError(_ context.Context, conn net.Conn, err error) {
func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) {
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
}
// Start the broker server with optional TLS support
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
switch msg.Command {
case consts.PUBLISH:
b.PublishHandler(ctx, conn, msg)
case consts.SUBSCRIBE:
b.SubscribeHandler(ctx, conn, msg)
case consts.MESSAGE_RESPONSE:
b.MessageResponseHandler(ctx, msg)
case consts.MESSAGE_ACK:
b.MessageAck(ctx, msg)
}
}
func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) {
consumerID, _ := GetConsumerID(ctx)
taskID, _ := jsonparser.GetString(msg.Payload, "id")
log.Printf("BROKER - MESSAGE_ACK ~> %s on %s for Task %s", consumerID, msg.Queue, taskID)
}
func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) {
msg.Command = consts.RESPONSE
headers, ok := GetHeaders(ctx)
if !ok {
return
}
b.HandleCallback(ctx, msg)
awaitResponse, ok := headers[consts.AwaitResponseKey]
if !(ok && awaitResponse == "true") {
return
}
publisherID, exists := headers[consts.PublisherKey]
if !exists {
return
}
con, ok := b.publishers.Get(publisherID)
if !ok {
return
}
err := b.send(con.conn, msg)
if err != nil {
panic(err)
}
}
func (b *Broker) Publish(ctx context.Context, task Task, queue string) error {
headers, _ := GetHeaders(ctx)
payload, _ := json.Marshal(task)
msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers)
b.broadcastToConsumers(ctx, msg)
return nil
}
func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
pub := b.addPublisher(ctx, msg.Queue, conn)
taskID, _ := jsonparser.GetString(msg.Payload, "id")
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
ack := codec.NewMessage(consts.PUBLISH_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
if err := b.send(conn, ack); err != nil {
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
}
b.broadcastToConsumers(ctx, msg)
go func() {
select {
case <-ctx.Done():
b.publishers.Del(pub.id)
}
}()
}
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
consumerID := b.addConsumer(ctx, msg.Queue, conn)
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
if err := b.send(conn, ack); err != nil {
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
}
go func() {
select {
case <-ctx.Done():
b.removeConsumer(msg.Queue, consumerID)
}
}()
}
func (b *Broker) Start(ctx context.Context) error {
var listener net.Listener
var err error
@@ -178,113 +176,61 @@ func (b *Broker) Start(ctx context.Context) error {
if err != nil {
return fmt.Errorf("failed to start TLS listener: %v", err)
}
log.Println("TLS server started on", b.opts.brokerAddr)
log.Println("BROKER - RUNNING_TLS ~> started on", b.opts.brokerAddr)
} else {
listener, err = net.Listen("tcp", b.opts.brokerAddr)
if err != nil {
return fmt.Errorf("failed to start TCP listener: %v", err)
}
log.Println("TCP server started on", b.opts.brokerAddr)
log.Println("BROKER - RUNNING ~> started on", b.opts.brokerAddr)
}
defer listener.Close()
for {
conn, err := listener.Accept()
if err != nil {
fmt.Println("Error accepting connection:", err)
b.OnError(ctx, conn, err)
continue
}
go ReadFromConn(ctx, conn, Handlers{
MessageHandler: b.opts.messageHandler,
CloseHandler: b.opts.closeHandler,
ErrorHandler: b.opts.errorHandler,
})
}
}
func (b *Broker) Publish(ctx context.Context, message Task, queueName string) Result {
queue, task, err := b.AddMessageToQueue(&message, queueName)
go func(c net.Conn) {
defer c.Close()
for {
err := b.readMessage(ctx, c)
if err != nil {
return Result{Error: err}
break
}
result := Result{
Command: "PUBLISH",
Payload: message.Payload,
Queue: queueName,
MessageID: task.ID,
}
if queue.consumers.Size() == 0 {
queue.deferred.Set(NewID(), &message)
fmt.Println("task deferred as no consumers are connected", queueName)
return result
}(conn)
}
queue.send(ctx, message)
return result
}
func (b *Broker) NewQueue(qName string) *Queue {
q, ok := b.queues.Get(qName)
if ok {
return q
}
q = newQueue(qName)
b.queues.Set(qName, q)
return q
func (b *Broker) send(conn net.Conn, msg *codec.Message) error {
return codec.SendMessage(conn, msg, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
}
func (b *Broker) AddMessageToQueue(task *Task, queueName string) (*Queue, *Task, error) {
queue := b.NewQueue(queueName)
if task.ID == "" {
task.ID = NewID()
}
if queueName != "" {
task.CurrentQueue = queueName
}
task.CreatedAt = time.Now()
queue.messages.Set(task.ID, task)
return queue, task, nil
func (b *Broker) receive(c net.Conn) (*codec.Message, error) {
return codec.ReadMessage(c, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
}
func (b *Broker) HandleProcessedMessage(ctx context.Context, result Result) error {
publisherID, ok := GetPublisherID(ctx)
if ok && publisherID != "" {
err := b.sendToPublisher(ctx, publisherID, result)
func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) {
if queue, ok := b.queues.Get(msg.Queue); ok {
task := &QueuedTask{Message: msg, RetryCount: 0}
queue.tasks <- task
}
}
func (b *Broker) waitForConsumerAck(conn net.Conn) error {
msg, err := b.receive(conn)
if err != nil {
return err
}
}
for _, callback := range b.opts.callback {
if callback != nil {
rs := callback(ctx, result)
if rs.Error != nil {
return rs.Error
}
}
}
if msg.Command == consts.MESSAGE_ACK {
log.Println("Received CONSUMER_ACK: Subscribed successfully")
return nil
}
return fmt.Errorf("expected CONSUMER_ACK, got: %v", msg.Command)
}
func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string {
consumerID, ok := GetConsumerID(ctx)
defer func() {
cmd := Command{
Command: consts.SUBSCRIBE_ACK,
Queue: queueName,
Error: "",
}
Write(ctx, conn, cmd)
log.Printf("Consumer %s joined server on queue %s", consumerID, queueName)
}()
q, ok := b.queues.Get(queueName)
if !ok {
q = b.NewQueue(queueName)
}
con := &consumer{id: consumerID, conn: conn}
b.consumers.Set(consumerID, con)
q.consumers.Set(consumerID, con)
return consumerID
}
func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) string {
func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) *publisher {
publisherID, ok := GetPublisherID(ctx)
_, ok = b.queues.Get(queueName)
if !ok {
@@ -292,20 +238,22 @@ func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Co
}
con := &publisher{id: publisherID, conn: conn}
b.publishers.Set(publisherID, con)
return publisherID
return con
}
func (b *Broker) subscribe(ctx context.Context, queueName string, conn net.Conn) {
consumerID := b.addConsumer(ctx, queueName, conn)
go func() {
select {
case <-ctx.Done():
b.removeConsumer(queueName, consumerID)
func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string {
consumerID, ok := GetConsumerID(ctx)
q, ok := b.queues.Get(queueName)
if !ok {
q = b.NewQueue(queueName)
}
}()
con := &consumer{id: consumerID, conn: conn}
b.consumers.Set(consumerID, con)
q.consumers.Set(consumerID, con)
log.Printf("BROKER - SUBSCRIBE ~> %s on %s", consumerID, queueName)
return consumerID
}
// Removes connection from the queue and broker
func (b *Broker) removeConsumer(queueName, consumerID string) {
if queue, ok := b.queues.Get(queueName); ok {
con, ok := queue.consumers.Get(consumerID)
@@ -317,57 +265,59 @@ func (b *Broker) removeConsumer(queueName, consumerID string) {
}
}
func (b *Broker) onMessage(ctx context.Context, conn net.Conn, message []byte) error {
var cmdMsg Command
var resultMsg Result
err := json.Unmarshal(message, &cmdMsg)
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
msg, err := b.receive(c)
if err == nil {
return b.handleCommandMessage(ctx, conn, cmdMsg)
}
err = json.Unmarshal(message, &resultMsg)
if err == nil {
return b.handleTaskMessage(ctx, conn, resultMsg)
}
ctx = SetHeaders(ctx, msg.Headers)
b.OnMessage(ctx, msg, c)
return nil
}
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
b.OnClose(ctx, c)
return err
}
b.OnError(ctx, c, err)
return err
}
func (b *Broker) handleTaskMessage(ctx context.Context, _ net.Conn, msg Result) error {
return b.HandleProcessedMessage(ctx, msg)
func (b *Broker) dispatchWorker(queue *Queue) {
delay := b.opts.initialDelay
for task := range queue.tasks {
success := false
for !success && task.RetryCount <= b.opts.maxRetries {
if b.dispatchTaskToConsumer(queue, task) {
success = true
} else {
task.RetryCount++
delay = b.backoffRetry(queue, task, delay)
}
}
}
}
func (b *Broker) publish(ctx context.Context, conn net.Conn, msg Command) error {
status := "PUBLISH"
if msg.Command == consts.REQUEST {
status = "REQUEST"
func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool {
var consumerFound bool
queue.consumers.ForEach(func(_ string, con *consumer) bool {
if err := b.send(con.conn, task.Message); err == nil {
consumerFound = true
return false // break the loop once a consumer is found
}
b.addPublisher(ctx, msg.Queue, conn)
task := Task{
ID: msg.MessageID,
Payload: msg.Payload,
CreatedAt: time.Now(),
CurrentQueue: msg.Queue,
return true
})
if !consumerFound {
log.Printf("No available consumers for queue %s, retrying...", queue.name)
}
result := b.Publish(ctx, task, msg.Queue)
if result.Error != nil {
return result.Error
}
if task.ID != "" {
result.Status = status
result.MessageID = task.ID
result.Queue = msg.Queue
return Write(ctx, conn, result)
}
return nil
return consumerFound
}
func (b *Broker) handleCommandMessage(ctx context.Context, conn net.Conn, msg Command) error {
switch msg.Command {
case consts.SUBSCRIBE:
b.subscribe(ctx, msg.Queue, conn)
return nil
case consts.PUBLISH, consts.REQUEST:
return b.publish(ctx, conn, msg)
default:
return fmt.Errorf("unknown command: %d", msg.Command)
func (b *Broker) backoffRetry(queue *Queue, task *QueuedTask, delay time.Duration) time.Duration {
backoffDuration := utils.CalculateJitter(delay, b.opts.jitterPercent)
log.Printf("Backing off for %v before retrying task for queue %s", backoffDuration, task.Message.Queue)
time.Sleep(backoffDuration)
queue.tasks <- task
delay *= 2
if delay > b.opts.maxBackoff {
delay = b.opts.maxBackoff
}
return delay
}

View File

@@ -7,10 +7,13 @@ import (
"fmt"
"log"
"net"
"strings"
"sync"
"time"
"github.com/oarkflow/mq/codec"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/jsonparser"
"github.com/oarkflow/mq/utils"
)
@@ -25,16 +28,20 @@ type Consumer struct {
// NewConsumer initializes a new consumer with the provided options.
func NewConsumer(id string, opts ...Option) *Consumer {
options := defaultOptions()
for _, opt := range opts {
opt(&options)
}
b := &Consumer{
options := setupOptions(opts...)
return &Consumer{
handlers: make(map[string]Handler),
id: id,
opts: options,
}
b.opts = defaultHandlers(options, b.onMessage, b.onClose, b.onError)
return b
}
func (c *Consumer) send(conn net.Conn, msg *codec.Message) error {
return codec.SendMessage(conn, msg, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
}
func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) {
return codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
}
// Close closes the consumer's connection.
@@ -43,90 +50,82 @@ func (c *Consumer) Close() error {
}
// Subscribe to a specific queue.
func (c *Consumer) subscribe(queue string) error {
ctx := context.Background()
ctx = SetHeaders(ctx, map[string]string{
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
headers := WithHeaders(ctx, map[string]string{
consts.ConsumerKey: c.id,
consts.ContentType: consts.TypeJson,
})
subscribe := Command{
Command: consts.SUBSCRIBE,
Queue: queue,
ID: NewID(),
msg := codec.NewMessage(consts.SUBSCRIBE, nil, queue, headers)
if err := c.send(c.conn, msg); err != nil {
return err
}
return c.waitForAck(c.conn)
}
func (c *Consumer) OnClose(ctx context.Context, _ net.Conn) error {
fmt.Println("Consumer closed")
return nil
}
func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) {
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
}
func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
headers := WithHeaders(ctx, map[string]string{
consts.ConsumerKey: c.id,
consts.ContentType: consts.TypeJson,
})
taskID, _ := jsonparser.GetString(msg.Payload, "id")
reply := codec.NewMessage(consts.MESSAGE_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
if err := c.send(conn, reply); err != nil {
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
}
var task Task
err := json.Unmarshal(msg.Payload, &task)
if err != nil {
log.Println("Error unmarshalling message:", err)
return
}
ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue})
result := c.ProcessTask(ctx, task)
result.MessageID = task.ID
result.Queue = msg.Queue
if result.Error != nil {
result.Status = "FAILED"
} else {
result.Status = "SUCCESS"
}
bt, _ := json.Marshal(result)
reply = codec.NewMessage(consts.MESSAGE_RESPONSE, bt, msg.Queue, headers)
if err := c.send(conn, reply); err != nil {
fmt.Printf("failed to send MESSAGE_RESPONSE for queue %s: %v", msg.Queue, err)
}
return Write(ctx, c.conn, subscribe)
}
// ProcessTask handles a received task message and invokes the appropriate handler.
func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result {
handler, exists := c.handlers[msg.CurrentQueue]
queue, _ := GetQueue(ctx)
handler, exists := c.handlers[queue]
if !exists {
return Result{Error: errors.New("No handler for queue " + msg.CurrentQueue)}
return Result{Error: errors.New("No handler for queue " + queue)}
}
return handler(ctx, msg)
}
// Handle command message sent by the server.
func (c *Consumer) handleCommandMessage(msg Command) error {
switch msg.Command {
case consts.STOP:
return c.Close()
case consts.SUBSCRIBE_ACK:
log.Printf("Consumer %s subscribed to queue %s\n", c.id, msg.Queue)
return nil
default:
return fmt.Errorf("unknown command in consumer %d", msg.Command)
}
}
// Handle task message sent by the server.
func (c *Consumer) handleTaskMessage(ctx context.Context, msg Task) error {
response := c.ProcessTask(ctx, msg)
response.Queue = msg.CurrentQueue
if msg.ID == "" {
response.Error = errors.New("task ID is empty")
response.Command = "error"
} else {
response.Command = "completed"
response.MessageID = msg.ID
}
return c.sendResult(ctx, response)
}
// Send the result of task processing back to the server.
func (c *Consumer) sendResult(ctx context.Context, response Result) error {
return Write(ctx, c.conn, response)
}
// Read and handle incoming messages.
func (c *Consumer) readMessage(ctx context.Context, message []byte) error {
var cmdMsg Command
var task Task
err := json.Unmarshal(message, &cmdMsg)
if err == nil && cmdMsg.Command != 0 {
return c.handleCommandMessage(cmdMsg)
}
err = json.Unmarshal(message, &task)
if err == nil {
return c.handleTaskMessage(ctx, task)
}
return nil
}
// AttemptConnect tries to establish a connection to the server, with TLS or without, based on the configuration.
func (c *Consumer) AttemptConnect() error {
var err error
delay := c.opts.initialDelay
for i := 0; i < c.opts.maxRetries; i++ {
conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig)
if err == nil {
c.conn = conn
return nil
}
sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent)
fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
log.Printf("CONSUMER - SUBSCRIBE ~> Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
time.Sleep(sleepDuration)
delay *= 2
if delay > c.opts.maxBackoff {
@@ -137,20 +136,19 @@ func (c *Consumer) AttemptConnect() error {
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err)
}
// onMessage reads incoming messages from the connection.
func (c *Consumer) onMessage(ctx context.Context, conn net.Conn, message []byte) error {
return c.readMessage(ctx, message)
}
// onClose handles connection close event.
func (c *Consumer) onClose(ctx context.Context, conn net.Conn) error {
fmt.Println("Consumer Connection closed", c.id, conn.RemoteAddr())
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
msg, err := c.receive(conn)
if err == nil {
ctx = SetHeaders(ctx, msg.Headers)
c.OnMessage(ctx, msg, conn)
return nil
}
// onError handles errors while reading from the connection.
func (c *Consumer) onError(ctx context.Context, conn net.Conn, err error) {
fmt.Println("Error reading from consumer connection:", err, conn.RemoteAddr())
}
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
c.OnClose(ctx, conn)
return err
}
c.OnError(ctx, conn, err)
return err
}
// Consume starts the consumer to consume tasks from the queues.
@@ -159,26 +157,39 @@ func (c *Consumer) Consume(ctx context.Context) error {
if err != nil {
return err
}
for _, q := range c.queues {
if err := c.subscribe(ctx, q); err != nil {
return fmt.Errorf("failed to connect to server for queue %s: %v", q, err)
}
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
ReadFromConn(ctx, c.conn, Handlers{
MessageHandler: c.opts.messageHandler,
CloseHandler: c.opts.closeHandler,
ErrorHandler: c.opts.errorHandler,
})
fmt.Println("Stopping consumer")
for {
if err := c.readMessage(ctx, c.conn); err != nil {
log.Println("Error reading message:", err)
break
}
}
}()
for _, q := range c.queues {
if err := c.subscribe(q); err != nil {
return fmt.Errorf("failed to connect to server for queue %s: %v", q, err)
}
}
wg.Wait()
return nil
}
func (c *Consumer) waitForAck(conn net.Conn) error {
msg, err := c.receive(conn)
if err != nil {
return err
}
if msg.Command == consts.SUBSCRIBE_ACK {
log.Printf("CONSUMER - SUBSCRIBE_ACK ~> %s on %s", c.id, msg.Queue)
return nil
}
return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command)
}
// RegisterHandler registers a handler for a queue.
func (c *Consumer) RegisterHandler(queue string, handler Handler) {
c.queues = append(c.queues, queue)

118
ctx.go
View File

@@ -1,36 +1,31 @@
package mq
import (
"bufio"
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"net"
"os"
"strings"
"time"
"github.com/oarkflow/xid"
"github.com/oarkflow/mq/codec"
"github.com/oarkflow/mq/consts"
)
type MessageHandler func(context.Context, net.Conn, []byte) error
type CloseHandler func(context.Context, net.Conn) error
type ErrorHandler func(context.Context, net.Conn, error)
type Handlers struct {
MessageHandler MessageHandler
CloseHandler CloseHandler
ErrorHandler ErrorHandler
type Task struct {
ID string `json:"id"`
Payload json.RawMessage `json:"payload"`
CreatedAt time.Time `json:"created_at"`
ProcessedAt time.Time `json:"processed_at"`
Status string `json:"status"`
Error error `json:"error"`
}
type Handler func(context.Context, Task) Result
func IsClosed(conn net.Conn) bool {
_, err := conn.Read(make([]byte, 1))
if err != nil {
@@ -52,11 +47,31 @@ func SetHeaders(ctx context.Context, headers map[string]string) context.Context
return context.WithValue(ctx, consts.HeaderKey, hd)
}
func WithHeaders(ctx context.Context, headers map[string]string) map[string]string {
hd, ok := GetHeaders(ctx)
if !ok {
hd = make(map[string]string)
}
for key, val := range headers {
hd[key] = val
}
return hd
}
func GetHeaders(ctx context.Context) (map[string]string, bool) {
headers, ok := ctx.Value(consts.HeaderKey).(map[string]string)
return headers, ok
}
func GetHeader(ctx context.Context, key string) (string, bool) {
headers, ok := ctx.Value(consts.HeaderKey).(map[string]string)
if !ok {
return "", false
}
val, ok := headers[key]
return val, ok
}
func GetContentType(ctx context.Context) (string, bool) {
headers, ok := GetHeaders(ctx)
if !ok {
@@ -66,6 +81,15 @@ func GetContentType(ctx context.Context) (string, bool) {
return contentType, ok
}
func GetQueue(ctx context.Context) (string, bool) {
headers, ok := GetHeaders(ctx)
if !ok {
return "", false
}
contentType, ok := headers[consts.QueueKey]
return contentType, ok
}
func GetConsumerID(ctx context.Context) (string, bool) {
headers, ok := GetHeaders(ctx)
if !ok {
@@ -93,70 +117,6 @@ func GetPublisherID(ctx context.Context) (string, bool) {
return contentType, ok
}
func Write(ctx context.Context, conn net.Conn, data any) error {
msg := codec.Message{Headers: make(map[string]string)}
if headers, ok := GetHeaders(ctx); ok {
msg.Headers = headers
}
dataBytes, err := json.Marshal(data)
if err != nil {
return err
}
msg.Payload = dataBytes
messageBytes, err := json.Marshal(msg)
if err != nil {
return err
}
_, err = conn.Write(append(messageBytes, '\n'))
return err
}
func ReadFromConn(ctx context.Context, conn net.Conn, handlers Handlers) {
defer func() {
if handlers.CloseHandler != nil {
if err := handlers.CloseHandler(ctx, conn); err != nil {
fmt.Println("Error in close handler:", err)
}
}
conn.Close()
}()
reader := bufio.NewReader(conn)
for {
messageBytes, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF || IsClosed(conn) || strings.Contains(err.Error(), "closed network connection") {
break
}
if handlers.ErrorHandler != nil {
handlers.ErrorHandler(ctx, conn, err)
}
continue
}
messageBytes = bytes.TrimSpace(messageBytes)
if len(messageBytes) == 0 {
continue
}
var msg codec.Message
err = json.Unmarshal(messageBytes, &msg)
if err != nil {
if handlers.ErrorHandler != nil {
handlers.ErrorHandler(ctx, conn, err)
}
continue
}
ctx = SetHeaders(ctx, msg.Headers)
if handlers.MessageHandler != nil {
err = handlers.MessageHandler(ctx, conn, msg.Payload)
if err != nil {
if handlers.ErrorHandler != nil {
handlers.ErrorHandler(ctx, conn, err)
}
continue
}
}
}
}
func NewID() string {
return xid.New().String()
}

View File

@@ -3,13 +3,13 @@ package dag
import (
"context"
"encoding/json"
"fmt"
"github.com/oarkflow/mq/consts"
"log"
"net/http"
"sync"
"time"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/consts"
)
type taskContext struct {
@@ -76,12 +76,20 @@ func (d *DAG) Start(ctx context.Context, addr string) error {
if d.server.SyncMode() {
return nil
}
for _, con := range d.nodes {
go con.Consume(ctx)
}
go func() {
d.server.Start(ctx)
err := d.server.Start(ctx)
if err != nil {
panic(err)
}
}()
for _, con := range d.nodes {
go func(con *mq.Consumer) {
err := con.Consume(ctx)
if err != nil {
panic(err)
}
}(con)
}
log.Printf("HTTP server started on %s", addr)
config := d.server.TLSConfig()
if config.UseTLS {
@@ -90,16 +98,6 @@ func (d *DAG) Start(ctx context.Context, addr string) error {
return http.ListenAndServe(addr, nil)
}
func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string, taskID ...string) mq.Result {
task := mq.Task{
Payload: payload,
}
if len(taskID) > 0 {
task.ID = taskID[0]
}
return d.server.Publish(ctx, task, queueName)
}
func (d *DAG) FindFirstNode() (string, bool) {
inDegree := make(map[string]int)
for n, _ := range d.nodes {
@@ -121,86 +119,23 @@ func (d *DAG) FindFirstNode() (string, bool) {
return "", false
}
func (d *DAG) Request(ctx context.Context, payload []byte) mq.Result {
return d.sendSync(ctx, mq.Result{Payload: payload})
}
func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result {
if d.FirstNode == "" {
return mq.Result{Error: fmt.Errorf("initial node not defined")}
func (d *DAG) PublishTask(ctx context.Context, payload json.RawMessage, taskID ...string) error {
queue, ok := mq.GetQueue(ctx)
if !ok {
queue = d.FirstNode
}
if d.server.SyncMode() {
return d.sendSync(ctx, mq.Result{Payload: payload})
var id string
if len(taskID) > 0 {
id = taskID[0]
} else {
id = mq.NewID()
}
resultCh := make(chan mq.Result)
result := d.PublishTask(ctx, payload, d.FirstNode)
if result.Error != nil {
return result
task := mq.Task{
ID: id,
Payload: payload,
CreatedAt: time.Now(),
}
d.mu.Lock()
d.taskChMap[result.MessageID] = resultCh
d.mu.Unlock()
finalResult := <-resultCh
return finalResult
}
func (d *DAG) processNode(ctx context.Context, task mq.Result) mq.Result {
if con, ok := d.nodes[task.Queue]; ok {
return con.ProcessTask(ctx, mq.Task{
ID: task.MessageID,
Payload: task.Payload,
CurrentQueue: task.Queue,
})
}
return mq.Result{Error: fmt.Errorf("no consumer to process %s", task.Queue)}
}
func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result {
if task.MessageID == "" {
task.MessageID = mq.NewID()
}
if task.Queue == "" {
task.Queue = d.FirstNode
}
result := d.processNode(ctx, task)
if result.Error != nil {
return result
}
for _, target := range d.loopEdges[task.Queue] {
var items, results []json.RawMessage
if err := json.Unmarshal(result.Payload, &items); err != nil {
return mq.Result{Error: err}
}
for _, item := range items {
result = d.sendSync(ctx, mq.Result{
Command: result.Command,
Payload: item,
Queue: target,
MessageID: result.MessageID,
})
if result.Error != nil {
return result
}
results = append(results, result.Payload)
}
bt, err := json.Marshal(results)
if err != nil {
return mq.Result{Error: err}
}
result.Payload = bt
}
if target, ok := d.edges[task.Queue]; ok {
result = d.sendSync(ctx, mq.Result{
Command: result.Command,
Payload: result.Payload,
Queue: target,
MessageID: result.MessageID,
})
if result.Error != nil {
return result
}
}
return result
return d.server.Publish(ctx, task, queue)
}
func (d *DAG) getCompletedResults(task mq.Result, ok bool, triggeredNode string) ([]byte, bool, bool) {
@@ -264,9 +199,12 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue})
for _, loopNode := range loopNodes {
for _, item := range items {
rs := d.PublishTask(ctx, item, loopNode, task.MessageID)
if rs.Error != nil {
return rs
ctx = mq.SetHeaders(ctx, map[string]string{
consts.QueueKey: loopNode,
})
err := d.PublishTask(ctx, item, task.MessageID)
if err != nil {
return mq.Result{Error: err}
}
}
}
@@ -284,15 +222,14 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
totalItems: 1,
},
}
rs := d.PublishTask(ctx, payload, edge, task.MessageID)
if rs.Error != nil {
return rs
err := d.PublishTask(ctx, payload, edge, task.MessageID)
if err != nil {
return mq.Result{Error: err}
}
} else if completed {
d.mu.Lock()
if resultCh, ok := d.taskChMap[task.MessageID]; ok {
resultCh <- mq.Result{
Command: "complete",
Payload: payload,
Queue: task.Queue,
MessageID: task.MessageID,

View File

@@ -2,9 +2,9 @@ package main
import (
"context"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/examples/tasks"
mq "github.com/oarkflow/mq/v2"
)
func main() {

View File

@@ -2,19 +2,15 @@ package main
import (
"context"
"encoding/json"
"io"
"net/http"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/dag"
"github.com/oarkflow/mq/examples/tasks"
"time"
)
var d *dag.DAG
func main() {
d = dag.New(mq.WithTLS(true, "server.crt", "server.key"), mq.WithCAPath("ca.crt"))
d = dag.New()
d.AddNode("queue1", tasks.Node1)
d.AddNode("queue2", tasks.Node2)
d.AddNode("queue3", tasks.Node3)
@@ -24,45 +20,14 @@ func main() {
d.AddLoop("queue2", "queue3")
d.AddEdge("queue2", "queue4")
d.Prepare()
http.HandleFunc("POST /publish", requestHandler("publish"))
http.HandleFunc("POST /request", requestHandler("request"))
err := d.Start(context.TODO(), ":8083")
go func() {
d.Start(context.Background(), ":8081")
}()
time.Sleep(5 * time.Second)
err := d.PublishTask(context.Background(), []byte(`{"tast": 123}`))
if err != nil {
panic(err)
}
}
func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
return
}
var payload []byte
if r.Body != nil {
defer r.Body.Close()
var err error
payload, err = io.ReadAll(r.Body)
if err != nil {
http.Error(w, "Failed to read request body", http.StatusBadRequest)
return
}
} else {
http.Error(w, "Empty request body", http.StatusBadRequest)
return
}
var rs mq.Result
if requestType == "request" {
rs = d.Request(context.Background(), payload)
} else {
rs = d.Send(context.Background(), payload)
}
w.Header().Set("Content-Type", "application/json")
result := map[string]any{
"message_id": rs.MessageID,
"payload": string(rs.Payload),
"error": rs.Error,
}
json.NewEncoder(w).Encode(result)
}
time.Sleep(10 * time.Second)
}

View File

@@ -3,17 +3,16 @@ package main
import (
"context"
"fmt"
mq2 "github.com/oarkflow/mq"
"time"
mq "github.com/oarkflow/mq/v2"
)
func main() {
payload := []byte(`{"message":"Message Publisher \n Task"}`)
task := mq.Task{
task := mq2.Task{
Payload: payload,
}
publisher := mq.NewPublisher("publish-1")
publisher := mq2.NewPublisher("publish-1")
// publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"))
err := publisher.Publish(context.Background(), task, "queue1")
if err != nil {
@@ -21,7 +20,7 @@ func main() {
}
fmt.Println("Async task published successfully")
payload = []byte(`{"message":"Fire-and-Forget \n Task"}`)
task = mq.Task{
task = mq2.Task{
Payload: payload,
}
for i := 0; i < 100; i++ {

View File

@@ -2,13 +2,13 @@ package main
import (
"context"
mq2 "github.com/oarkflow/mq"
"github.com/oarkflow/mq/examples/tasks"
mq "github.com/oarkflow/mq/v2"
)
func main() {
b := mq.NewBroker(mq.WithCallback(tasks.Callback))
b := mq2.NewBroker(mq2.WithCallback(tasks.Callback))
// b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert"))
b.NewQueue("queue1")
b.NewQueue("queue2")

View File

@@ -4,42 +4,41 @@ import (
"context"
"encoding/json"
"fmt"
mq "github.com/oarkflow/mq/v2"
mq2 "github.com/oarkflow/mq"
)
func Node1(ctx context.Context, task mq.Task) mq.Result {
func Node1(ctx context.Context, task mq2.Task) mq2.Result {
fmt.Println("Processing queue1", task.ID)
return mq.Result{Payload: task.Payload, MessageID: task.ID}
return mq2.Result{Payload: task.Payload, MessageID: task.ID}
}
func Node2(ctx context.Context, task mq.Task) mq.Result {
return mq.Result{Payload: task.Payload, MessageID: task.ID}
func Node2(ctx context.Context, task mq2.Task) mq2.Result {
return mq2.Result{Payload: task.Payload, MessageID: task.ID}
}
func Node3(ctx context.Context, task mq.Task) mq.Result {
func Node3(ctx context.Context, task mq2.Task) mq2.Result {
var data map[string]any
err := json.Unmarshal(task.Payload, &data)
if err != nil {
return mq.Result{Error: err}
return mq2.Result{Error: err}
}
data["salary"] = fmt.Sprintf("12000%v", data["user_id"])
bt, _ := json.Marshal(data)
return mq.Result{Payload: bt, MessageID: task.ID}
return mq2.Result{Payload: bt, MessageID: task.ID}
}
func Node4(ctx context.Context, task mq.Task) mq.Result {
func Node4(ctx context.Context, task mq2.Task) mq2.Result {
var data []map[string]any
err := json.Unmarshal(task.Payload, &data)
if err != nil {
return mq.Result{Error: err}
return mq2.Result{Error: err}
}
payload := map[string]any{"storage": data}
bt, _ := json.Marshal(payload)
return mq.Result{Payload: bt, MessageID: task.ID}
return mq2.Result{Payload: bt, MessageID: task.ID}
}
func Callback(ctx context.Context, task mq.Result) mq.Result {
func Callback(ctx context.Context, task mq2.Result) mq2.Result {
fmt.Println("Received task", task.MessageID, "Payload", string(task.Payload), task.Error, task.Queue)
return mq.Result{}
return mq2.Result{}
}

View File

@@ -2,9 +2,18 @@ package mq
import (
"context"
"encoding/json"
"time"
)
type Result struct {
Payload json.RawMessage `json:"payload"`
Queue string `json:"queue"`
MessageID string `json:"message_id"`
Error error `json:"error,omitempty"`
Status string `json:"status"`
}
type TLSConfig struct {
UseTLS bool
CertPath string
@@ -15,15 +24,16 @@ type TLSConfig struct {
type Options struct {
syncMode bool
brokerAddr string
messageHandler MessageHandler
closeHandler CloseHandler
errorHandler ErrorHandler
callback []func(context.Context, Result) Result
maxRetries int
initialDelay time.Duration
maxBackoff time.Duration
jitterPercent float64
tlsConfig TLSConfig
aesKey json.RawMessage
hmacKey json.RawMessage
enableEncryption bool
queueSize int
}
func defaultOptions() Options {
@@ -34,27 +44,29 @@ func defaultOptions() Options {
initialDelay: 2 * time.Second,
maxBackoff: 20 * time.Second,
jitterPercent: 0.5,
queueSize: 100,
}
}
func defaultHandlers(options Options, onMessage MessageHandler, onClose CloseHandler, onError ErrorHandler) Options {
if options.messageHandler == nil {
options.messageHandler = onMessage
}
if options.closeHandler == nil {
options.closeHandler = onClose
}
if options.errorHandler == nil {
options.errorHandler = onError
}
return options
}
// Option defines a function type for setting options.
type Option func(*Options)
func setupOptions(opts ...Option) Options {
options := defaultOptions()
for _, opt := range opts {
opt(&options)
}
return options
}
func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option {
return func(opts *Options) {
opts.aesKey = aesKey
opts.hmacKey = hmacKey
opts.enableEncryption = enableEncryption
}
}
// WithBrokerURL -
func WithBrokerURL(url string) Option {
return func(opts *Options) {
@@ -119,24 +131,3 @@ func WithJitterPercent(val float64) Option {
opts.jitterPercent = val
}
}
// WithMessageHandler sets a custom MessageHandler.
func WithMessageHandler(handler MessageHandler) Option {
return func(opts *Options) {
opts.messageHandler = handler
}
}
// WithErrorHandler sets a custom ErrorHandler.
func WithErrorHandler(handler ErrorHandler) Option {
return func(opts *Options) {
opts.errorHandler = handler
}
}
// WithCloseHandler sets a custom CloseHandler.
func WithCloseHandler(handler CloseHandler) Option {
return func(opts *Options) {
opts.closeHandler = handler
}
}

View File

@@ -4,9 +4,13 @@ import (
"context"
"encoding/json"
"fmt"
"log"
"net"
"time"
"github.com/oarkflow/mq/codec"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/jsonparser"
)
type Publisher struct {
@@ -15,31 +19,59 @@ type Publisher struct {
}
func NewPublisher(id string, opts ...Option) *Publisher {
options := defaultOptions()
for _, opt := range opts {
opt(&options)
}
b := &Publisher{id: id}
b.opts = defaultHandlers(options, nil, b.onClose, b.onError)
return b
options := setupOptions(opts...)
return &Publisher{id: id, opts: options}
}
func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error {
ctx = SetHeaders(ctx, map[string]string{
headers := WithHeaders(ctx, map[string]string{
consts.PublisherKey: p.id,
consts.ContentType: consts.TypeJson,
})
cmd := Command{
ID: NewID(),
Command: command,
Queue: queue,
MessageID: task.ID,
Payload: task.Payload,
if task.ID == "" {
task.ID = NewID()
}
return Write(ctx, conn, cmd)
task.CreatedAt = time.Now()
payload, err := json.Marshal(task)
if err != nil {
return err
}
msg := codec.NewMessage(command, payload, queue, headers)
if err := codec.SendMessage(conn, msg, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption); err != nil {
return err
}
return p.waitForAck(conn)
}
func (p *Publisher) Publish(ctx context.Context, queue string, task Task) error {
func (p *Publisher) waitForAck(conn net.Conn) error {
msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption)
if err != nil {
return err
}
if msg.Command == consts.PUBLISH_ACK {
taskID, _ := jsonparser.GetString(msg.Payload, "id")
log.Printf("PUBLISHER - PUBLISH_ACK ~> from %s on %s for Task %s", p.id, msg.Queue, taskID)
return nil
}
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
}
func (p *Publisher) waitForResponse(conn net.Conn) Result {
msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption)
if err != nil {
return Result{Error: err}
}
if msg.Command == consts.RESPONSE {
var result Result
err = json.Unmarshal(msg.Payload, &result)
return result
}
err = fmt.Errorf("expected RESPONSE, got: %v", msg.Command)
return Result{Error: err}
}
func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error {
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
if err != nil {
return fmt.Errorf("failed to connect to broker: %w", err)
@@ -57,30 +89,22 @@ func (p *Publisher) onError(ctx context.Context, conn net.Conn, err error) {
fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr())
}
func (p *Publisher) Request(ctx context.Context, queue string, task Task) (Result, error) {
func (p *Publisher) Request(ctx context.Context, queue string, task Task) Result {
ctx = SetHeaders(ctx, map[string]string{
consts.AwaitResponseKey: "true",
})
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
if err != nil {
return Result{Error: err}, fmt.Errorf("failed to connect to broker: %w", err)
err = fmt.Errorf("failed to connect to broker: %w", err)
return Result{Error: err}
}
defer conn.Close()
var result Result
err = p.send(ctx, queue, task, conn, consts.REQUEST)
if err != nil {
return result, err
}
if p.opts.messageHandler == nil {
p.opts.messageHandler = func(ctx context.Context, conn net.Conn, message []byte) error {
err := json.Unmarshal(message, &result)
if err != nil {
return err
}
return conn.Close()
}
}
ReadFromConn(ctx, conn, Handlers{
MessageHandler: p.opts.messageHandler,
CloseHandler: p.opts.closeHandler,
ErrorHandler: p.opts.errorHandler,
})
return result, nil
err = p.send(ctx, queue, task, conn, consts.PUBLISH)
resultCh := make(chan Result)
go func() {
defer close(resultCh)
resultCh <- p.waitForResponse(conn)
}()
finalResult := <-resultCh
return finalResult
}

View File

@@ -1,4 +1,4 @@
package v2
package mq
import (
"github.com/oarkflow/xsync"

View File

@@ -1,4 +1,4 @@
package v2
package mq
import (
"context"

View File

@@ -1,323 +0,0 @@
package v2
import (
"context"
"crypto/tls"
"fmt"
"log"
"net"
"strings"
"time"
"github.com/oarkflow/xsync"
"github.com/oarkflow/mq/codec"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/jsonparser"
"github.com/oarkflow/mq/utils"
)
type QueuedTask struct {
Message *codec.Message
RetryCount int
}
type consumer struct {
id string
conn net.Conn
}
type publisher struct {
id string
conn net.Conn
}
type Broker struct {
queues xsync.IMap[string, *Queue]
consumers xsync.IMap[string, *consumer]
publishers xsync.IMap[string, *publisher]
opts Options
}
func NewBroker(opts ...Option) *Broker {
options := setupOptions(opts...)
return &Broker{
queues: xsync.NewMap[string, *Queue](),
publishers: xsync.NewMap[string, *publisher](),
consumers: xsync.NewMap[string, *consumer](),
opts: options,
}
}
func (b *Broker) OnClose(ctx context.Context, _ net.Conn) error {
consumerID, ok := GetConsumerID(ctx)
if ok && consumerID != "" {
if con, exists := b.consumers.Get(consumerID); exists {
con.conn.Close()
b.consumers.Del(consumerID)
}
b.queues.ForEach(func(_ string, queue *Queue) bool {
queue.consumers.Del(consumerID)
return true
})
}
publisherID, ok := GetPublisherID(ctx)
if ok && publisherID != "" {
if con, exists := b.publishers.Get(publisherID); exists {
con.conn.Close()
b.publishers.Del(publisherID)
}
}
return nil
}
func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) {
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
}
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
switch msg.Command {
case consts.PUBLISH:
b.PublishHandler(ctx, conn, msg)
case consts.SUBSCRIBE:
b.SubscribeHandler(ctx, conn, msg)
case consts.MESSAGE_RESPONSE:
b.MessageResponseHandler(ctx, msg)
case consts.MESSAGE_ACK:
b.MessageAck(ctx, msg)
}
}
func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) {
consumerID, _ := GetConsumerID(ctx)
taskID, _ := jsonparser.GetString(msg.Payload, "id")
log.Printf("BROKER - MESSAGE_ACK ~> %s on %s for Task %s", consumerID, msg.Queue, taskID)
}
func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) {
msg.Command = consts.RESPONSE
headers, ok := GetHeaders(ctx)
if !ok {
return
}
b.HandleCallback(ctx, msg)
awaitResponse, ok := headers[consts.AwaitResponseKey]
if !(ok && awaitResponse == "true") {
return
}
publisherID, exists := headers[consts.PublisherKey]
if !exists {
return
}
con, ok := b.publishers.Get(publisherID)
if !ok {
return
}
err := b.send(con.conn, msg)
if err != nil {
panic(err)
}
}
func (b *Broker) Publish(ctx context.Context, task Task, queue string) error {
headers, _ := GetHeaders(ctx)
msg := codec.NewMessage(consts.PUBLISH, task.Payload, queue, headers)
b.broadcastToConsumers(ctx, msg)
return nil
}
func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
pub := b.addPublisher(ctx, msg.Queue, conn)
taskID, _ := jsonparser.GetString(msg.Payload, "id")
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
ack := codec.NewMessage(consts.PUBLISH_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
if err := b.send(conn, ack); err != nil {
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
}
b.broadcastToConsumers(ctx, msg)
go func() {
select {
case <-ctx.Done():
b.publishers.Del(pub.id)
}
}()
}
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
consumerID := b.addConsumer(ctx, msg.Queue, conn)
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
if err := b.send(conn, ack); err != nil {
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
}
go func() {
select {
case <-ctx.Done():
b.removeConsumer(msg.Queue, consumerID)
}
}()
}
func (b *Broker) Start(ctx context.Context) error {
var listener net.Listener
var err error
if b.opts.tlsConfig.UseTLS {
cert, err := tls.LoadX509KeyPair(b.opts.tlsConfig.CertPath, b.opts.tlsConfig.KeyPath)
if err != nil {
return fmt.Errorf("failed to load TLS certificates: %v", err)
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
listener, err = tls.Listen("tcp", b.opts.brokerAddr, tlsConfig)
if err != nil {
return fmt.Errorf("failed to start TLS listener: %v", err)
}
log.Println("BROKER - RUNNING_TLS ~> started on", b.opts.brokerAddr)
} else {
listener, err = net.Listen("tcp", b.opts.brokerAddr)
if err != nil {
return fmt.Errorf("failed to start TCP listener: %v", err)
}
log.Println("BROKER - RUNNING ~> started on", b.opts.brokerAddr)
}
defer listener.Close()
for {
conn, err := listener.Accept()
if err != nil {
b.OnError(ctx, conn, err)
continue
}
go func(c net.Conn) {
defer c.Close()
for {
err := b.readMessage(ctx, c)
if err != nil {
break
}
}
}(conn)
}
}
func (b *Broker) send(conn net.Conn, msg *codec.Message) error {
return codec.SendMessage(conn, msg, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
}
func (b *Broker) receive(c net.Conn) (*codec.Message, error) {
return codec.ReadMessage(c, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
}
func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) {
if queue, ok := b.queues.Get(msg.Queue); ok {
task := &QueuedTask{Message: msg, RetryCount: 0}
queue.tasks <- task
log.Printf("Task enqueued for queue %s", msg.Queue)
}
}
func (b *Broker) waitForConsumerAck(conn net.Conn) error {
msg, err := b.receive(conn)
if err != nil {
return err
}
if msg.Command == consts.MESSAGE_ACK {
log.Println("Received CONSUMER_ACK: Subscribed successfully")
return nil
}
return fmt.Errorf("expected CONSUMER_ACK, got: %v", msg.Command)
}
func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) *publisher {
publisherID, ok := GetPublisherID(ctx)
_, ok = b.queues.Get(queueName)
if !ok {
b.NewQueue(queueName)
}
con := &publisher{id: publisherID, conn: conn}
b.publishers.Set(publisherID, con)
return con
}
func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string {
consumerID, ok := GetConsumerID(ctx)
q, ok := b.queues.Get(queueName)
if !ok {
q = b.NewQueue(queueName)
}
con := &consumer{id: consumerID, conn: conn}
b.consumers.Set(consumerID, con)
q.consumers.Set(consumerID, con)
log.Printf("BROKER - SUBSCRIBE ~> %s on %s", consumerID, queueName)
return consumerID
}
func (b *Broker) removeConsumer(queueName, consumerID string) {
if queue, ok := b.queues.Get(queueName); ok {
con, ok := queue.consumers.Get(consumerID)
if ok {
con.conn.Close()
queue.consumers.Del(consumerID)
}
b.queues.Del(queueName)
}
}
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
msg, err := b.receive(c)
if err == nil {
ctx = SetHeaders(ctx, msg.Headers)
b.OnMessage(ctx, msg, c)
return nil
}
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
b.OnClose(ctx, c)
return err
}
b.OnError(ctx, c, err)
return err
}
func (b *Broker) dispatchWorker(queue *Queue) {
delay := b.opts.initialDelay
for task := range queue.tasks {
success := false
for !success && task.RetryCount <= b.opts.maxRetries {
if b.dispatchTaskToConsumer(queue, task) {
success = true
} else {
task.RetryCount++
delay = b.backoffRetry(queue, task, delay)
}
}
}
}
func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool {
var consumerFound bool
queue.consumers.ForEach(func(_ string, con *consumer) bool {
if err := b.send(con.conn, task.Message); err == nil {
consumerFound = true
log.Printf("Task dispatched to consumer %s on queue %s", con.id, queue.name)
return false // break the loop once a consumer is found
}
return true
})
if !consumerFound {
log.Printf("No available consumers for queue %s, retrying...", queue.name)
}
return consumerFound
}
func (b *Broker) backoffRetry(queue *Queue, task *QueuedTask, delay time.Duration) time.Duration {
backoffDuration := utils.CalculateJitter(delay, b.opts.jitterPercent)
log.Printf("Backing off for %v before retrying task for queue %s", backoffDuration, task.Message.Queue)
time.Sleep(backoffDuration)
queue.tasks <- task
delay *= 2
if delay > b.opts.maxBackoff {
delay = b.opts.maxBackoff
}
return delay
}

View File

@@ -1,197 +0,0 @@
package v2
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"strings"
"sync"
"time"
"github.com/oarkflow/mq/codec"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/jsonparser"
"github.com/oarkflow/mq/utils"
)
// Consumer structure to hold consumer-specific configurations and state.
type Consumer struct {
id string
handlers map[string]Handler
conn net.Conn
queues []string
opts Options
}
// NewConsumer initializes a new consumer with the provided options.
func NewConsumer(id string, opts ...Option) *Consumer {
options := setupOptions(opts...)
return &Consumer{
handlers: make(map[string]Handler),
id: id,
opts: options,
}
}
func (c *Consumer) send(conn net.Conn, msg *codec.Message) error {
return codec.SendMessage(conn, msg, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
}
func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) {
return codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
}
// Close closes the consumer's connection.
func (c *Consumer) Close() error {
return c.conn.Close()
}
// Subscribe to a specific queue.
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
headers := WithHeaders(ctx, map[string]string{
consts.ConsumerKey: c.id,
consts.ContentType: consts.TypeJson,
})
msg := codec.NewMessage(consts.SUBSCRIBE, nil, queue, headers)
if err := c.send(c.conn, msg); err != nil {
return err
}
return c.waitForAck(c.conn)
}
func (c *Consumer) OnClose(ctx context.Context, _ net.Conn) error {
fmt.Println("Consumer closed")
return nil
}
func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) {
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
}
func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
headers := WithHeaders(ctx, map[string]string{
consts.ConsumerKey: c.id,
consts.ContentType: consts.TypeJson,
})
taskID, _ := jsonparser.GetString(msg.Payload, "id")
reply := codec.NewMessage(consts.MESSAGE_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
if err := c.send(conn, reply); err != nil {
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
}
var task Task
err := json.Unmarshal(msg.Payload, &task)
if err != nil {
log.Println("Error unmarshalling message:", err)
return
}
ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue})
result := c.ProcessTask(ctx, task)
result.MessageID = task.ID
result.Queue = msg.Queue
if result.Error != nil {
result.Status = "FAILED"
} else {
result.Status = "SUCCESS"
}
bt, _ := json.Marshal(result)
reply = codec.NewMessage(consts.MESSAGE_RESPONSE, bt, msg.Queue, headers)
if err := c.send(conn, reply); err != nil {
fmt.Printf("failed to send MESSAGE_RESPONSE for queue %s: %v", msg.Queue, err)
}
}
// ProcessTask handles a received task message and invokes the appropriate handler.
func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result {
queue, _ := GetQueue(ctx)
handler, exists := c.handlers[queue]
if !exists {
return Result{Error: errors.New("No handler for queue " + queue)}
}
return handler(ctx, msg)
}
// AttemptConnect tries to establish a connection to the server, with TLS or without, based on the configuration.
func (c *Consumer) AttemptConnect() error {
var err error
delay := c.opts.initialDelay
for i := 0; i < c.opts.maxRetries; i++ {
conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig)
if err == nil {
c.conn = conn
return nil
}
sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent)
log.Printf("CONSUMER - SUBSCRIBE ~> Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
time.Sleep(sleepDuration)
delay *= 2
if delay > c.opts.maxBackoff {
delay = c.opts.maxBackoff
}
}
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err)
}
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
msg, err := c.receive(conn)
if err == nil {
ctx = SetHeaders(ctx, msg.Headers)
c.OnMessage(ctx, msg, conn)
return nil
}
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
c.OnClose(ctx, conn)
return err
}
c.OnError(ctx, conn, err)
return err
}
// Consume starts the consumer to consume tasks from the queues.
func (c *Consumer) Consume(ctx context.Context) error {
err := c.AttemptConnect()
if err != nil {
return err
}
for _, q := range c.queues {
if err := c.subscribe(ctx, q); err != nil {
return fmt.Errorf("failed to connect to server for queue %s: %v", q, err)
}
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for {
if err := c.readMessage(ctx, c.conn); err != nil {
log.Println("Error reading message:", err)
break
}
}
}()
wg.Wait()
return nil
}
func (c *Consumer) waitForAck(conn net.Conn) error {
msg, err := c.receive(conn)
if err != nil {
return err
}
if msg.Command == consts.SUBSCRIBE_ACK {
log.Printf("CONSUMER - SUBSCRIBE_ACK ~> %s on %s", c.id, msg.Queue)
return nil
}
return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command)
}
// RegisterHandler registers a handler for a queue.
func (c *Consumer) RegisterHandler(queue string, handler Handler) {
c.queues = append(c.queues, queue)
c.handlers[queue] = handler
}

158
v2/ctx.go
View File

@@ -1,158 +0,0 @@
package v2
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"net"
"os"
"time"
"github.com/oarkflow/xid"
"github.com/oarkflow/mq/consts"
)
type Task struct {
ID string `json:"id"`
Payload json.RawMessage `json:"payload"`
CreatedAt time.Time `json:"created_at"`
ProcessedAt time.Time `json:"processed_at"`
Status string `json:"status"`
Error error `json:"error"`
}
type Handler func(context.Context, Task) Result
func IsClosed(conn net.Conn) bool {
_, err := conn.Read(make([]byte, 1))
if err != nil {
if err == net.ErrClosed {
return true
}
}
return false
}
func SetHeaders(ctx context.Context, headers map[string]string) context.Context {
hd, ok := GetHeaders(ctx)
if !ok {
hd = make(map[string]string)
}
for key, val := range headers {
hd[key] = val
}
return context.WithValue(ctx, consts.HeaderKey, hd)
}
func WithHeaders(ctx context.Context, headers map[string]string) map[string]string {
hd, ok := GetHeaders(ctx)
if !ok {
hd = make(map[string]string)
}
for key, val := range headers {
hd[key] = val
}
return hd
}
func GetHeaders(ctx context.Context) (map[string]string, bool) {
headers, ok := ctx.Value(consts.HeaderKey).(map[string]string)
return headers, ok
}
func GetHeader(ctx context.Context, key string) (string, bool) {
headers, ok := ctx.Value(consts.HeaderKey).(map[string]string)
if !ok {
return "", false
}
val, ok := headers[key]
return val, ok
}
func GetContentType(ctx context.Context) (string, bool) {
headers, ok := GetHeaders(ctx)
if !ok {
return "", false
}
contentType, ok := headers[consts.ContentType]
return contentType, ok
}
func GetQueue(ctx context.Context) (string, bool) {
headers, ok := GetHeaders(ctx)
if !ok {
return "", false
}
contentType, ok := headers[consts.QueueKey]
return contentType, ok
}
func GetConsumerID(ctx context.Context) (string, bool) {
headers, ok := GetHeaders(ctx)
if !ok {
return "", false
}
contentType, ok := headers[consts.ConsumerKey]
return contentType, ok
}
func GetTriggerNode(ctx context.Context) (string, bool) {
headers, ok := GetHeaders(ctx)
if !ok {
return "", false
}
contentType, ok := headers[consts.TriggerNode]
return contentType, ok
}
func GetPublisherID(ctx context.Context) (string, bool) {
headers, ok := GetHeaders(ctx)
if !ok {
return "", false
}
contentType, ok := headers[consts.PublisherKey]
return contentType, ok
}
func NewID() string {
return xid.New().String()
}
func createTLSConnection(addr, certPath, keyPath string, caPath ...string) (net.Conn, error) {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, fmt.Errorf("failed to load client cert/key: %w", err)
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
ClientAuth: tls.RequireAndVerifyClientCert,
InsecureSkipVerify: true,
}
if len(caPath) > 0 && caPath[0] != "" {
caCert, err := os.ReadFile(caPath[0])
if err != nil {
return nil, fmt.Errorf("failed to load CA cert: %w", err)
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
tlsConfig.RootCAs = caCertPool
tlsConfig.ClientCAs = caCertPool
}
conn, err := tls.Dial("tcp", addr, tlsConfig)
if err != nil {
return nil, fmt.Errorf("failed to dial TLS connection: %w", err)
}
return conn, nil
}
func GetConnection(addr string, config TLSConfig) (net.Conn, error) {
if config.UseTLS {
return createTLSConnection(addr, config.CertPath, config.KeyPath, config.CAPath)
} else {
return net.Dial("tcp", addr)
}
}

View File

@@ -1,133 +0,0 @@
package v2
import (
"context"
"encoding/json"
"time"
)
type Result struct {
Payload json.RawMessage `json:"payload"`
Queue string `json:"queue"`
MessageID string `json:"message_id"`
Error error `json:"error,omitempty"`
Status string `json:"status"`
}
type TLSConfig struct {
UseTLS bool
CertPath string
KeyPath string
CAPath string
}
type Options struct {
syncMode bool
brokerAddr string
callback []func(context.Context, Result) Result
maxRetries int
initialDelay time.Duration
maxBackoff time.Duration
jitterPercent float64
tlsConfig TLSConfig
aesKey json.RawMessage
hmacKey json.RawMessage
enableEncryption bool
queueSize int
}
func defaultOptions() Options {
return Options{
syncMode: false,
brokerAddr: ":8080",
maxRetries: 5,
initialDelay: 2 * time.Second,
maxBackoff: 20 * time.Second,
jitterPercent: 0.5,
queueSize: 100,
}
}
// Option defines a function type for setting options.
type Option func(*Options)
func setupOptions(opts ...Option) Options {
options := defaultOptions()
for _, opt := range opts {
opt(&options)
}
return options
}
func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option {
return func(opts *Options) {
opts.aesKey = aesKey
opts.hmacKey = hmacKey
opts.enableEncryption = enableEncryption
}
}
// WithBrokerURL -
func WithBrokerURL(url string) Option {
return func(opts *Options) {
opts.brokerAddr = url
}
}
// WithTLS - Option to enable/disable TLS
func WithTLS(enableTLS bool, certPath, keyPath string) Option {
return func(o *Options) {
o.tlsConfig.UseTLS = enableTLS
o.tlsConfig.CertPath = certPath
o.tlsConfig.KeyPath = keyPath
}
}
// WithCAPath - Option to enable/disable TLS
func WithCAPath(caPath string) Option {
return func(o *Options) {
o.tlsConfig.CAPath = caPath
}
}
// WithSyncMode -
func WithSyncMode(mode bool) Option {
return func(opts *Options) {
opts.syncMode = mode
}
}
// WithMaxRetries -
func WithMaxRetries(val int) Option {
return func(opts *Options) {
opts.maxRetries = val
}
}
// WithInitialDelay -
func WithInitialDelay(val time.Duration) Option {
return func(opts *Options) {
opts.initialDelay = val
}
}
// WithMaxBackoff -
func WithMaxBackoff(val time.Duration) Option {
return func(opts *Options) {
opts.maxBackoff = val
}
}
// WithCallback -
func WithCallback(val ...func(context.Context, Result) Result) Option {
return func(opts *Options) {
opts.callback = val
}
}
// WithJitterPercent -
func WithJitterPercent(val float64) Option {
return func(opts *Options) {
opts.jitterPercent = val
}
}

View File

@@ -1,110 +0,0 @@
package v2
import (
"context"
"encoding/json"
"fmt"
"log"
"net"
"time"
"github.com/oarkflow/mq/codec"
"github.com/oarkflow/mq/consts"
"github.com/oarkflow/mq/jsonparser"
)
type Publisher struct {
id string
opts Options
}
func NewPublisher(id string, opts ...Option) *Publisher {
options := setupOptions(opts...)
return &Publisher{id: id, opts: options}
}
func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error {
headers := WithHeaders(ctx, map[string]string{
consts.PublisherKey: p.id,
consts.ContentType: consts.TypeJson,
})
if task.ID == "" {
task.ID = NewID()
}
task.CreatedAt = time.Now()
payload, err := json.Marshal(task)
if err != nil {
return err
}
msg := codec.NewMessage(command, payload, queue, headers)
if err := codec.SendMessage(conn, msg, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption); err != nil {
return err
}
return p.waitForAck(conn)
}
func (p *Publisher) waitForAck(conn net.Conn) error {
msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption)
if err != nil {
return err
}
if msg.Command == consts.PUBLISH_ACK {
taskID, _ := jsonparser.GetString(msg.Payload, "id")
log.Printf("PUBLISHER - PUBLISH_ACK ~> from %s on %s for Task %s", p.id, msg.Queue, taskID)
return nil
}
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
}
func (p *Publisher) waitForResponse(conn net.Conn) Result {
msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption)
if err != nil {
return Result{Error: err}
}
if msg.Command == consts.RESPONSE {
var result Result
err = json.Unmarshal(msg.Payload, &result)
return result
}
err = fmt.Errorf("expected RESPONSE, got: %v", msg.Command)
return Result{Error: err}
}
func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error {
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
if err != nil {
return fmt.Errorf("failed to connect to broker: %w", err)
}
defer conn.Close()
return p.send(ctx, queue, task, conn, consts.PUBLISH)
}
func (p *Publisher) onClose(ctx context.Context, conn net.Conn) error {
fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr())
return nil
}
func (p *Publisher) onError(ctx context.Context, conn net.Conn, err error) {
fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr())
}
func (p *Publisher) Request(ctx context.Context, queue string, task Task) Result {
ctx = SetHeaders(ctx, map[string]string{
consts.AwaitResponseKey: "true",
})
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
if err != nil {
err = fmt.Errorf("failed to connect to broker: %w", err)
return Result{Error: err}
}
defer conn.Close()
err = p.send(ctx, queue, task, conn, consts.PUBLISH)
resultCh := make(chan Result)
go func() {
defer close(resultCh)
resultCh <- p.waitForResponse(conn)
}()
finalResult := <-resultCh
return finalResult
}