mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-05 16:06:55 +08:00
feat: separate broker
This commit is contained in:
410
broker.go
410
broker.go
@@ -4,37 +4,35 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/xsync"
|
||||
|
||||
"github.com/oarkflow/mq/codec"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
"github.com/oarkflow/mq/jsonparser"
|
||||
"github.com/oarkflow/mq/utils"
|
||||
)
|
||||
|
||||
type QueuedTask struct {
|
||||
Message *codec.Message
|
||||
RetryCount int
|
||||
}
|
||||
|
||||
type consumer struct {
|
||||
id string
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (p *consumer) send(ctx context.Context, cmd any) error {
|
||||
return Write(ctx, p.conn, cmd)
|
||||
}
|
||||
|
||||
type publisher struct {
|
||||
id string
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
func (p *publisher) send(ctx context.Context, cmd any) error {
|
||||
return Write(ctx, p.conn, cmd)
|
||||
}
|
||||
|
||||
type Handler func(context.Context, Task) Result
|
||||
|
||||
type Broker struct {
|
||||
queues xsync.IMap[string, *Queue]
|
||||
consumers xsync.IMap[string, *consumer]
|
||||
@@ -42,100 +40,17 @@ type Broker struct {
|
||||
opts Options
|
||||
}
|
||||
|
||||
type Queue struct {
|
||||
name string
|
||||
consumers xsync.IMap[string, *consumer]
|
||||
messages xsync.IMap[string, *Task]
|
||||
deferred xsync.IMap[string, *Task]
|
||||
}
|
||||
|
||||
func newQueue(name string) *Queue {
|
||||
return &Queue{
|
||||
name: name,
|
||||
consumers: xsync.NewMap[string, *consumer](),
|
||||
messages: xsync.NewMap[string, *Task](),
|
||||
deferred: xsync.NewMap[string, *Task](),
|
||||
}
|
||||
}
|
||||
|
||||
func (queue *Queue) send(ctx context.Context, cmd any) {
|
||||
queue.consumers.ForEach(func(_ string, client *consumer) bool {
|
||||
err := client.send(ctx, cmd)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
type Task struct {
|
||||
ID string `json:"id"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ProcessedAt time.Time `json:"processed_at"`
|
||||
CurrentQueue string `json:"current_queue"`
|
||||
Status string `json:"status"`
|
||||
Error error `json:"error"`
|
||||
}
|
||||
|
||||
type Command struct {
|
||||
ID string `json:"id"`
|
||||
Command consts.CMD `json:"command"`
|
||||
Queue string `json:"queue"`
|
||||
MessageID string `json:"message_id"`
|
||||
Payload json.RawMessage `json:"payload,omitempty"` // Used for carrying the task payload
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type Result struct {
|
||||
Command string `json:"command"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
Queue string `json:"queue"`
|
||||
MessageID string `json:"message_id"`
|
||||
Error error `json:"error"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func NewBroker(opts ...Option) *Broker {
|
||||
options := defaultOptions()
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
b := &Broker{
|
||||
options := setupOptions(opts...)
|
||||
return &Broker{
|
||||
queues: xsync.NewMap[string, *Queue](),
|
||||
publishers: xsync.NewMap[string, *publisher](),
|
||||
consumers: xsync.NewMap[string, *consumer](),
|
||||
opts: options,
|
||||
}
|
||||
b.opts = defaultHandlers(options, b.onMessage, b.onClose, b.onError)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Broker) Send(ctx context.Context, cmd Command) error {
|
||||
queue, ok := b.queues.Get(cmd.Queue)
|
||||
if !ok || queue == nil {
|
||||
return errors.New("invalid queue or not exists")
|
||||
}
|
||||
queue.send(ctx, cmd)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Broker) TLSConfig() TLSConfig {
|
||||
return b.opts.tlsConfig
|
||||
}
|
||||
|
||||
func (b *Broker) SyncMode() bool {
|
||||
return b.opts.syncMode
|
||||
}
|
||||
|
||||
func (b *Broker) sendToPublisher(ctx context.Context, publisherID string, result Result) error {
|
||||
pub, ok := b.publishers.Get(publisherID)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return pub.send(ctx, result)
|
||||
}
|
||||
|
||||
func (b *Broker) onClose(ctx context.Context, _ net.Conn) error {
|
||||
func (b *Broker) OnClose(ctx context.Context, _ net.Conn) error {
|
||||
consumerID, ok := GetConsumerID(ctx)
|
||||
if ok && consumerID != "" {
|
||||
if con, exists := b.consumers.Get(consumerID); exists {
|
||||
@@ -157,11 +72,94 @@ func (b *Broker) onClose(ctx context.Context, _ net.Conn) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Broker) onError(_ context.Context, conn net.Conn, err error) {
|
||||
func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) {
|
||||
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
|
||||
}
|
||||
|
||||
// Start the broker server with optional TLS support
|
||||
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||||
switch msg.Command {
|
||||
case consts.PUBLISH:
|
||||
b.PublishHandler(ctx, conn, msg)
|
||||
case consts.SUBSCRIBE:
|
||||
b.SubscribeHandler(ctx, conn, msg)
|
||||
case consts.MESSAGE_RESPONSE:
|
||||
b.MessageResponseHandler(ctx, msg)
|
||||
case consts.MESSAGE_ACK:
|
||||
b.MessageAck(ctx, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) {
|
||||
consumerID, _ := GetConsumerID(ctx)
|
||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||
log.Printf("BROKER - MESSAGE_ACK ~> %s on %s for Task %s", consumerID, msg.Queue, taskID)
|
||||
}
|
||||
|
||||
func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) {
|
||||
msg.Command = consts.RESPONSE
|
||||
headers, ok := GetHeaders(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
b.HandleCallback(ctx, msg)
|
||||
awaitResponse, ok := headers[consts.AwaitResponseKey]
|
||||
if !(ok && awaitResponse == "true") {
|
||||
return
|
||||
}
|
||||
publisherID, exists := headers[consts.PublisherKey]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
con, ok := b.publishers.Get(publisherID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
err := b.send(con.conn, msg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) Publish(ctx context.Context, task Task, queue string) error {
|
||||
headers, _ := GetHeaders(ctx)
|
||||
payload, _ := json.Marshal(task)
|
||||
msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers)
|
||||
b.broadcastToConsumers(ctx, msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
||||
pub := b.addPublisher(ctx, msg.Queue, conn)
|
||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
|
||||
|
||||
ack := codec.NewMessage(consts.PUBLISH_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
|
||||
if err := b.send(conn, ack); err != nil {
|
||||
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
|
||||
}
|
||||
b.broadcastToConsumers(ctx, msg)
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
b.publishers.Del(pub.id)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
||||
consumerID := b.addConsumer(ctx, msg.Queue, conn)
|
||||
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
|
||||
if err := b.send(conn, ack); err != nil {
|
||||
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
|
||||
}
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
b.removeConsumer(msg.Queue, consumerID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *Broker) Start(ctx context.Context) error {
|
||||
var listener net.Listener
|
||||
var err error
|
||||
@@ -178,113 +176,61 @@ func (b *Broker) Start(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start TLS listener: %v", err)
|
||||
}
|
||||
log.Println("TLS server started on", b.opts.brokerAddr)
|
||||
log.Println("BROKER - RUNNING_TLS ~> started on", b.opts.brokerAddr)
|
||||
} else {
|
||||
listener, err = net.Listen("tcp", b.opts.brokerAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start TCP listener: %v", err)
|
||||
}
|
||||
log.Println("TCP server started on", b.opts.brokerAddr)
|
||||
log.Println("BROKER - RUNNING ~> started on", b.opts.brokerAddr)
|
||||
}
|
||||
defer listener.Close()
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
fmt.Println("Error accepting connection:", err)
|
||||
b.OnError(ctx, conn, err)
|
||||
continue
|
||||
}
|
||||
go ReadFromConn(ctx, conn, Handlers{
|
||||
MessageHandler: b.opts.messageHandler,
|
||||
CloseHandler: b.opts.closeHandler,
|
||||
ErrorHandler: b.opts.errorHandler,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) Publish(ctx context.Context, message Task, queueName string) Result {
|
||||
queue, task, err := b.AddMessageToQueue(&message, queueName)
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
for {
|
||||
err := b.readMessage(ctx, c)
|
||||
if err != nil {
|
||||
return Result{Error: err}
|
||||
break
|
||||
}
|
||||
result := Result{
|
||||
Command: "PUBLISH",
|
||||
Payload: message.Payload,
|
||||
Queue: queueName,
|
||||
MessageID: task.ID,
|
||||
}
|
||||
if queue.consumers.Size() == 0 {
|
||||
queue.deferred.Set(NewID(), &message)
|
||||
fmt.Println("task deferred as no consumers are connected", queueName)
|
||||
return result
|
||||
}(conn)
|
||||
}
|
||||
queue.send(ctx, message)
|
||||
return result
|
||||
}
|
||||
|
||||
func (b *Broker) NewQueue(qName string) *Queue {
|
||||
q, ok := b.queues.Get(qName)
|
||||
if ok {
|
||||
return q
|
||||
}
|
||||
q = newQueue(qName)
|
||||
b.queues.Set(qName, q)
|
||||
return q
|
||||
func (b *Broker) send(conn net.Conn, msg *codec.Message) error {
|
||||
return codec.SendMessage(conn, msg, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
|
||||
}
|
||||
|
||||
func (b *Broker) AddMessageToQueue(task *Task, queueName string) (*Queue, *Task, error) {
|
||||
queue := b.NewQueue(queueName)
|
||||
if task.ID == "" {
|
||||
task.ID = NewID()
|
||||
}
|
||||
if queueName != "" {
|
||||
task.CurrentQueue = queueName
|
||||
}
|
||||
task.CreatedAt = time.Now()
|
||||
queue.messages.Set(task.ID, task)
|
||||
return queue, task, nil
|
||||
func (b *Broker) receive(c net.Conn) (*codec.Message, error) {
|
||||
return codec.ReadMessage(c, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
|
||||
}
|
||||
|
||||
func (b *Broker) HandleProcessedMessage(ctx context.Context, result Result) error {
|
||||
publisherID, ok := GetPublisherID(ctx)
|
||||
if ok && publisherID != "" {
|
||||
err := b.sendToPublisher(ctx, publisherID, result)
|
||||
func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) {
|
||||
if queue, ok := b.queues.Get(msg.Queue); ok {
|
||||
task := &QueuedTask{Message: msg, RetryCount: 0}
|
||||
queue.tasks <- task
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) waitForConsumerAck(conn net.Conn) error {
|
||||
msg, err := b.receive(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, callback := range b.opts.callback {
|
||||
if callback != nil {
|
||||
rs := callback(ctx, result)
|
||||
if rs.Error != nil {
|
||||
return rs.Error
|
||||
}
|
||||
}
|
||||
}
|
||||
if msg.Command == consts.MESSAGE_ACK {
|
||||
log.Println("Received CONSUMER_ACK: Subscribed successfully")
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("expected CONSUMER_ACK, got: %v", msg.Command)
|
||||
}
|
||||
|
||||
func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string {
|
||||
consumerID, ok := GetConsumerID(ctx)
|
||||
defer func() {
|
||||
cmd := Command{
|
||||
Command: consts.SUBSCRIBE_ACK,
|
||||
Queue: queueName,
|
||||
Error: "",
|
||||
}
|
||||
Write(ctx, conn, cmd)
|
||||
log.Printf("Consumer %s joined server on queue %s", consumerID, queueName)
|
||||
}()
|
||||
q, ok := b.queues.Get(queueName)
|
||||
if !ok {
|
||||
q = b.NewQueue(queueName)
|
||||
}
|
||||
con := &consumer{id: consumerID, conn: conn}
|
||||
b.consumers.Set(consumerID, con)
|
||||
q.consumers.Set(consumerID, con)
|
||||
return consumerID
|
||||
}
|
||||
|
||||
func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) string {
|
||||
func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) *publisher {
|
||||
publisherID, ok := GetPublisherID(ctx)
|
||||
_, ok = b.queues.Get(queueName)
|
||||
if !ok {
|
||||
@@ -292,20 +238,22 @@ func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Co
|
||||
}
|
||||
con := &publisher{id: publisherID, conn: conn}
|
||||
b.publishers.Set(publisherID, con)
|
||||
return publisherID
|
||||
return con
|
||||
}
|
||||
|
||||
func (b *Broker) subscribe(ctx context.Context, queueName string, conn net.Conn) {
|
||||
consumerID := b.addConsumer(ctx, queueName, conn)
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
b.removeConsumer(queueName, consumerID)
|
||||
func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string {
|
||||
consumerID, ok := GetConsumerID(ctx)
|
||||
q, ok := b.queues.Get(queueName)
|
||||
if !ok {
|
||||
q = b.NewQueue(queueName)
|
||||
}
|
||||
}()
|
||||
con := &consumer{id: consumerID, conn: conn}
|
||||
b.consumers.Set(consumerID, con)
|
||||
q.consumers.Set(consumerID, con)
|
||||
log.Printf("BROKER - SUBSCRIBE ~> %s on %s", consumerID, queueName)
|
||||
return consumerID
|
||||
}
|
||||
|
||||
// Removes connection from the queue and broker
|
||||
func (b *Broker) removeConsumer(queueName, consumerID string) {
|
||||
if queue, ok := b.queues.Get(queueName); ok {
|
||||
con, ok := queue.consumers.Get(consumerID)
|
||||
@@ -317,57 +265,59 @@ func (b *Broker) removeConsumer(queueName, consumerID string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) onMessage(ctx context.Context, conn net.Conn, message []byte) error {
|
||||
var cmdMsg Command
|
||||
var resultMsg Result
|
||||
err := json.Unmarshal(message, &cmdMsg)
|
||||
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
|
||||
msg, err := b.receive(c)
|
||||
if err == nil {
|
||||
return b.handleCommandMessage(ctx, conn, cmdMsg)
|
||||
}
|
||||
err = json.Unmarshal(message, &resultMsg)
|
||||
if err == nil {
|
||||
return b.handleTaskMessage(ctx, conn, resultMsg)
|
||||
}
|
||||
ctx = SetHeaders(ctx, msg.Headers)
|
||||
b.OnMessage(ctx, msg, c)
|
||||
return nil
|
||||
}
|
||||
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
|
||||
b.OnClose(ctx, c)
|
||||
return err
|
||||
}
|
||||
b.OnError(ctx, c, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *Broker) handleTaskMessage(ctx context.Context, _ net.Conn, msg Result) error {
|
||||
return b.HandleProcessedMessage(ctx, msg)
|
||||
func (b *Broker) dispatchWorker(queue *Queue) {
|
||||
delay := b.opts.initialDelay
|
||||
for task := range queue.tasks {
|
||||
success := false
|
||||
for !success && task.RetryCount <= b.opts.maxRetries {
|
||||
if b.dispatchTaskToConsumer(queue, task) {
|
||||
success = true
|
||||
} else {
|
||||
task.RetryCount++
|
||||
delay = b.backoffRetry(queue, task, delay)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) publish(ctx context.Context, conn net.Conn, msg Command) error {
|
||||
status := "PUBLISH"
|
||||
if msg.Command == consts.REQUEST {
|
||||
status = "REQUEST"
|
||||
func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool {
|
||||
var consumerFound bool
|
||||
queue.consumers.ForEach(func(_ string, con *consumer) bool {
|
||||
if err := b.send(con.conn, task.Message); err == nil {
|
||||
consumerFound = true
|
||||
return false // break the loop once a consumer is found
|
||||
}
|
||||
b.addPublisher(ctx, msg.Queue, conn)
|
||||
task := Task{
|
||||
ID: msg.MessageID,
|
||||
Payload: msg.Payload,
|
||||
CreatedAt: time.Now(),
|
||||
CurrentQueue: msg.Queue,
|
||||
return true
|
||||
})
|
||||
if !consumerFound {
|
||||
log.Printf("No available consumers for queue %s, retrying...", queue.name)
|
||||
}
|
||||
result := b.Publish(ctx, task, msg.Queue)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if task.ID != "" {
|
||||
result.Status = status
|
||||
result.MessageID = task.ID
|
||||
result.Queue = msg.Queue
|
||||
return Write(ctx, conn, result)
|
||||
}
|
||||
return nil
|
||||
return consumerFound
|
||||
}
|
||||
|
||||
func (b *Broker) handleCommandMessage(ctx context.Context, conn net.Conn, msg Command) error {
|
||||
switch msg.Command {
|
||||
case consts.SUBSCRIBE:
|
||||
b.subscribe(ctx, msg.Queue, conn)
|
||||
return nil
|
||||
case consts.PUBLISH, consts.REQUEST:
|
||||
return b.publish(ctx, conn, msg)
|
||||
default:
|
||||
return fmt.Errorf("unknown command: %d", msg.Command)
|
||||
func (b *Broker) backoffRetry(queue *Queue, task *QueuedTask, delay time.Duration) time.Duration {
|
||||
backoffDuration := utils.CalculateJitter(delay, b.opts.jitterPercent)
|
||||
log.Printf("Backing off for %v before retrying task for queue %s", backoffDuration, task.Message.Queue)
|
||||
time.Sleep(backoffDuration)
|
||||
queue.tasks <- task
|
||||
delay *= 2
|
||||
if delay > b.opts.maxBackoff {
|
||||
delay = b.opts.maxBackoff
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
193
consumer.go
193
consumer.go
@@ -7,10 +7,13 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/codec"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
"github.com/oarkflow/mq/jsonparser"
|
||||
"github.com/oarkflow/mq/utils"
|
||||
)
|
||||
|
||||
@@ -25,16 +28,20 @@ type Consumer struct {
|
||||
|
||||
// NewConsumer initializes a new consumer with the provided options.
|
||||
func NewConsumer(id string, opts ...Option) *Consumer {
|
||||
options := defaultOptions()
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
b := &Consumer{
|
||||
options := setupOptions(opts...)
|
||||
return &Consumer{
|
||||
handlers: make(map[string]Handler),
|
||||
id: id,
|
||||
opts: options,
|
||||
}
|
||||
b.opts = defaultHandlers(options, b.onMessage, b.onClose, b.onError)
|
||||
return b
|
||||
}
|
||||
|
||||
func (c *Consumer) send(conn net.Conn, msg *codec.Message) error {
|
||||
return codec.SendMessage(conn, msg, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
|
||||
}
|
||||
|
||||
func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) {
|
||||
return codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
|
||||
}
|
||||
|
||||
// Close closes the consumer's connection.
|
||||
@@ -43,90 +50,82 @@ func (c *Consumer) Close() error {
|
||||
}
|
||||
|
||||
// Subscribe to a specific queue.
|
||||
func (c *Consumer) subscribe(queue string) error {
|
||||
ctx := context.Background()
|
||||
ctx = SetHeaders(ctx, map[string]string{
|
||||
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
|
||||
headers := WithHeaders(ctx, map[string]string{
|
||||
consts.ConsumerKey: c.id,
|
||||
consts.ContentType: consts.TypeJson,
|
||||
})
|
||||
subscribe := Command{
|
||||
Command: consts.SUBSCRIBE,
|
||||
Queue: queue,
|
||||
ID: NewID(),
|
||||
msg := codec.NewMessage(consts.SUBSCRIBE, nil, queue, headers)
|
||||
if err := c.send(c.conn, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.waitForAck(c.conn)
|
||||
}
|
||||
|
||||
func (c *Consumer) OnClose(ctx context.Context, _ net.Conn) error {
|
||||
fmt.Println("Consumer closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) {
|
||||
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
|
||||
}
|
||||
|
||||
func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||||
headers := WithHeaders(ctx, map[string]string{
|
||||
consts.ConsumerKey: c.id,
|
||||
consts.ContentType: consts.TypeJson,
|
||||
})
|
||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||
reply := codec.NewMessage(consts.MESSAGE_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
|
||||
if err := c.send(conn, reply); err != nil {
|
||||
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
|
||||
}
|
||||
var task Task
|
||||
err := json.Unmarshal(msg.Payload, &task)
|
||||
if err != nil {
|
||||
log.Println("Error unmarshalling message:", err)
|
||||
return
|
||||
}
|
||||
ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue})
|
||||
result := c.ProcessTask(ctx, task)
|
||||
result.MessageID = task.ID
|
||||
result.Queue = msg.Queue
|
||||
if result.Error != nil {
|
||||
result.Status = "FAILED"
|
||||
} else {
|
||||
result.Status = "SUCCESS"
|
||||
}
|
||||
bt, _ := json.Marshal(result)
|
||||
reply = codec.NewMessage(consts.MESSAGE_RESPONSE, bt, msg.Queue, headers)
|
||||
if err := c.send(conn, reply); err != nil {
|
||||
fmt.Printf("failed to send MESSAGE_RESPONSE for queue %s: %v", msg.Queue, err)
|
||||
}
|
||||
return Write(ctx, c.conn, subscribe)
|
||||
}
|
||||
|
||||
// ProcessTask handles a received task message and invokes the appropriate handler.
|
||||
func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result {
|
||||
handler, exists := c.handlers[msg.CurrentQueue]
|
||||
queue, _ := GetQueue(ctx)
|
||||
handler, exists := c.handlers[queue]
|
||||
if !exists {
|
||||
return Result{Error: errors.New("No handler for queue " + msg.CurrentQueue)}
|
||||
return Result{Error: errors.New("No handler for queue " + queue)}
|
||||
}
|
||||
return handler(ctx, msg)
|
||||
}
|
||||
|
||||
// Handle command message sent by the server.
|
||||
func (c *Consumer) handleCommandMessage(msg Command) error {
|
||||
switch msg.Command {
|
||||
case consts.STOP:
|
||||
return c.Close()
|
||||
case consts.SUBSCRIBE_ACK:
|
||||
log.Printf("Consumer %s subscribed to queue %s\n", c.id, msg.Queue)
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unknown command in consumer %d", msg.Command)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle task message sent by the server.
|
||||
func (c *Consumer) handleTaskMessage(ctx context.Context, msg Task) error {
|
||||
response := c.ProcessTask(ctx, msg)
|
||||
response.Queue = msg.CurrentQueue
|
||||
if msg.ID == "" {
|
||||
response.Error = errors.New("task ID is empty")
|
||||
response.Command = "error"
|
||||
} else {
|
||||
response.Command = "completed"
|
||||
response.MessageID = msg.ID
|
||||
}
|
||||
return c.sendResult(ctx, response)
|
||||
}
|
||||
|
||||
// Send the result of task processing back to the server.
|
||||
func (c *Consumer) sendResult(ctx context.Context, response Result) error {
|
||||
return Write(ctx, c.conn, response)
|
||||
}
|
||||
|
||||
// Read and handle incoming messages.
|
||||
func (c *Consumer) readMessage(ctx context.Context, message []byte) error {
|
||||
var cmdMsg Command
|
||||
var task Task
|
||||
err := json.Unmarshal(message, &cmdMsg)
|
||||
if err == nil && cmdMsg.Command != 0 {
|
||||
return c.handleCommandMessage(cmdMsg)
|
||||
}
|
||||
err = json.Unmarshal(message, &task)
|
||||
if err == nil {
|
||||
return c.handleTaskMessage(ctx, task)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AttemptConnect tries to establish a connection to the server, with TLS or without, based on the configuration.
|
||||
func (c *Consumer) AttemptConnect() error {
|
||||
var err error
|
||||
delay := c.opts.initialDelay
|
||||
|
||||
for i := 0; i < c.opts.maxRetries; i++ {
|
||||
conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig)
|
||||
if err == nil {
|
||||
c.conn = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent)
|
||||
fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
|
||||
log.Printf("CONSUMER - SUBSCRIBE ~> Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
|
||||
time.Sleep(sleepDuration)
|
||||
delay *= 2
|
||||
if delay > c.opts.maxBackoff {
|
||||
@@ -137,20 +136,19 @@ func (c *Consumer) AttemptConnect() error {
|
||||
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err)
|
||||
}
|
||||
|
||||
// onMessage reads incoming messages from the connection.
|
||||
func (c *Consumer) onMessage(ctx context.Context, conn net.Conn, message []byte) error {
|
||||
return c.readMessage(ctx, message)
|
||||
}
|
||||
|
||||
// onClose handles connection close event.
|
||||
func (c *Consumer) onClose(ctx context.Context, conn net.Conn) error {
|
||||
fmt.Println("Consumer Connection closed", c.id, conn.RemoteAddr())
|
||||
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
|
||||
msg, err := c.receive(conn)
|
||||
if err == nil {
|
||||
ctx = SetHeaders(ctx, msg.Headers)
|
||||
c.OnMessage(ctx, msg, conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
// onError handles errors while reading from the connection.
|
||||
func (c *Consumer) onError(ctx context.Context, conn net.Conn, err error) {
|
||||
fmt.Println("Error reading from consumer connection:", err, conn.RemoteAddr())
|
||||
}
|
||||
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
|
||||
c.OnClose(ctx, conn)
|
||||
return err
|
||||
}
|
||||
c.OnError(ctx, conn, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Consume starts the consumer to consume tasks from the queues.
|
||||
@@ -159,26 +157,39 @@ func (c *Consumer) Consume(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, q := range c.queues {
|
||||
if err := c.subscribe(ctx, q); err != nil {
|
||||
return fmt.Errorf("failed to connect to server for queue %s: %v", q, err)
|
||||
}
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
ReadFromConn(ctx, c.conn, Handlers{
|
||||
MessageHandler: c.opts.messageHandler,
|
||||
CloseHandler: c.opts.closeHandler,
|
||||
ErrorHandler: c.opts.errorHandler,
|
||||
})
|
||||
fmt.Println("Stopping consumer")
|
||||
for {
|
||||
if err := c.readMessage(ctx, c.conn); err != nil {
|
||||
log.Println("Error reading message:", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
for _, q := range c.queues {
|
||||
if err := c.subscribe(q); err != nil {
|
||||
return fmt.Errorf("failed to connect to server for queue %s: %v", q, err)
|
||||
}
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Consumer) waitForAck(conn net.Conn) error {
|
||||
msg, err := c.receive(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Command == consts.SUBSCRIBE_ACK {
|
||||
log.Printf("CONSUMER - SUBSCRIBE_ACK ~> %s on %s", c.id, msg.Queue)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command)
|
||||
}
|
||||
|
||||
// RegisterHandler registers a handler for a queue.
|
||||
func (c *Consumer) RegisterHandler(queue string, handler Handler) {
|
||||
c.queues = append(c.queues, queue)
|
||||
|
118
ctx.go
118
ctx.go
@@ -1,36 +1,31 @@
|
||||
package mq
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/xid"
|
||||
|
||||
"github.com/oarkflow/mq/codec"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
)
|
||||
|
||||
type MessageHandler func(context.Context, net.Conn, []byte) error
|
||||
|
||||
type CloseHandler func(context.Context, net.Conn) error
|
||||
|
||||
type ErrorHandler func(context.Context, net.Conn, error)
|
||||
|
||||
type Handlers struct {
|
||||
MessageHandler MessageHandler
|
||||
CloseHandler CloseHandler
|
||||
ErrorHandler ErrorHandler
|
||||
type Task struct {
|
||||
ID string `json:"id"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ProcessedAt time.Time `json:"processed_at"`
|
||||
Status string `json:"status"`
|
||||
Error error `json:"error"`
|
||||
}
|
||||
|
||||
type Handler func(context.Context, Task) Result
|
||||
|
||||
func IsClosed(conn net.Conn) bool {
|
||||
_, err := conn.Read(make([]byte, 1))
|
||||
if err != nil {
|
||||
@@ -52,11 +47,31 @@ func SetHeaders(ctx context.Context, headers map[string]string) context.Context
|
||||
return context.WithValue(ctx, consts.HeaderKey, hd)
|
||||
}
|
||||
|
||||
func WithHeaders(ctx context.Context, headers map[string]string) map[string]string {
|
||||
hd, ok := GetHeaders(ctx)
|
||||
if !ok {
|
||||
hd = make(map[string]string)
|
||||
}
|
||||
for key, val := range headers {
|
||||
hd[key] = val
|
||||
}
|
||||
return hd
|
||||
}
|
||||
|
||||
func GetHeaders(ctx context.Context) (map[string]string, bool) {
|
||||
headers, ok := ctx.Value(consts.HeaderKey).(map[string]string)
|
||||
return headers, ok
|
||||
}
|
||||
|
||||
func GetHeader(ctx context.Context, key string) (string, bool) {
|
||||
headers, ok := ctx.Value(consts.HeaderKey).(map[string]string)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
val, ok := headers[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
func GetContentType(ctx context.Context) (string, bool) {
|
||||
headers, ok := GetHeaders(ctx)
|
||||
if !ok {
|
||||
@@ -66,6 +81,15 @@ func GetContentType(ctx context.Context) (string, bool) {
|
||||
return contentType, ok
|
||||
}
|
||||
|
||||
func GetQueue(ctx context.Context) (string, bool) {
|
||||
headers, ok := GetHeaders(ctx)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
contentType, ok := headers[consts.QueueKey]
|
||||
return contentType, ok
|
||||
}
|
||||
|
||||
func GetConsumerID(ctx context.Context) (string, bool) {
|
||||
headers, ok := GetHeaders(ctx)
|
||||
if !ok {
|
||||
@@ -93,70 +117,6 @@ func GetPublisherID(ctx context.Context) (string, bool) {
|
||||
return contentType, ok
|
||||
}
|
||||
|
||||
func Write(ctx context.Context, conn net.Conn, data any) error {
|
||||
msg := codec.Message{Headers: make(map[string]string)}
|
||||
if headers, ok := GetHeaders(ctx); ok {
|
||||
msg.Headers = headers
|
||||
}
|
||||
dataBytes, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg.Payload = dataBytes
|
||||
messageBytes, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = conn.Write(append(messageBytes, '\n'))
|
||||
return err
|
||||
}
|
||||
|
||||
func ReadFromConn(ctx context.Context, conn net.Conn, handlers Handlers) {
|
||||
defer func() {
|
||||
if handlers.CloseHandler != nil {
|
||||
if err := handlers.CloseHandler(ctx, conn); err != nil {
|
||||
fmt.Println("Error in close handler:", err)
|
||||
}
|
||||
}
|
||||
conn.Close()
|
||||
}()
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
messageBytes, err := reader.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF || IsClosed(conn) || strings.Contains(err.Error(), "closed network connection") {
|
||||
break
|
||||
}
|
||||
if handlers.ErrorHandler != nil {
|
||||
handlers.ErrorHandler(ctx, conn, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
messageBytes = bytes.TrimSpace(messageBytes)
|
||||
if len(messageBytes) == 0 {
|
||||
continue
|
||||
}
|
||||
var msg codec.Message
|
||||
err = json.Unmarshal(messageBytes, &msg)
|
||||
if err != nil {
|
||||
if handlers.ErrorHandler != nil {
|
||||
handlers.ErrorHandler(ctx, conn, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
ctx = SetHeaders(ctx, msg.Headers)
|
||||
if handlers.MessageHandler != nil {
|
||||
err = handlers.MessageHandler(ctx, conn, msg.Payload)
|
||||
if err != nil {
|
||||
if handlers.ErrorHandler != nil {
|
||||
handlers.ErrorHandler(ctx, conn, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewID() string {
|
||||
return xid.New().String()
|
||||
}
|
||||
|
137
dag/dag.go
137
dag/dag.go
@@ -3,13 +3,13 @@ package dag
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
)
|
||||
|
||||
type taskContext struct {
|
||||
@@ -76,12 +76,20 @@ func (d *DAG) Start(ctx context.Context, addr string) error {
|
||||
if d.server.SyncMode() {
|
||||
return nil
|
||||
}
|
||||
for _, con := range d.nodes {
|
||||
go con.Consume(ctx)
|
||||
}
|
||||
go func() {
|
||||
d.server.Start(ctx)
|
||||
err := d.server.Start(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
for _, con := range d.nodes {
|
||||
go func(con *mq.Consumer) {
|
||||
err := con.Consume(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}(con)
|
||||
}
|
||||
log.Printf("HTTP server started on %s", addr)
|
||||
config := d.server.TLSConfig()
|
||||
if config.UseTLS {
|
||||
@@ -90,16 +98,6 @@ func (d *DAG) Start(ctx context.Context, addr string) error {
|
||||
return http.ListenAndServe(addr, nil)
|
||||
}
|
||||
|
||||
func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string, taskID ...string) mq.Result {
|
||||
task := mq.Task{
|
||||
Payload: payload,
|
||||
}
|
||||
if len(taskID) > 0 {
|
||||
task.ID = taskID[0]
|
||||
}
|
||||
return d.server.Publish(ctx, task, queueName)
|
||||
}
|
||||
|
||||
func (d *DAG) FindFirstNode() (string, bool) {
|
||||
inDegree := make(map[string]int)
|
||||
for n, _ := range d.nodes {
|
||||
@@ -121,86 +119,23 @@ func (d *DAG) FindFirstNode() (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
func (d *DAG) Request(ctx context.Context, payload []byte) mq.Result {
|
||||
return d.sendSync(ctx, mq.Result{Payload: payload})
|
||||
}
|
||||
|
||||
func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result {
|
||||
if d.FirstNode == "" {
|
||||
return mq.Result{Error: fmt.Errorf("initial node not defined")}
|
||||
func (d *DAG) PublishTask(ctx context.Context, payload json.RawMessage, taskID ...string) error {
|
||||
queue, ok := mq.GetQueue(ctx)
|
||||
if !ok {
|
||||
queue = d.FirstNode
|
||||
}
|
||||
if d.server.SyncMode() {
|
||||
return d.sendSync(ctx, mq.Result{Payload: payload})
|
||||
var id string
|
||||
if len(taskID) > 0 {
|
||||
id = taskID[0]
|
||||
} else {
|
||||
id = mq.NewID()
|
||||
}
|
||||
resultCh := make(chan mq.Result)
|
||||
result := d.PublishTask(ctx, payload, d.FirstNode)
|
||||
if result.Error != nil {
|
||||
return result
|
||||
task := mq.Task{
|
||||
ID: id,
|
||||
Payload: payload,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
d.mu.Lock()
|
||||
d.taskChMap[result.MessageID] = resultCh
|
||||
d.mu.Unlock()
|
||||
finalResult := <-resultCh
|
||||
return finalResult
|
||||
}
|
||||
|
||||
func (d *DAG) processNode(ctx context.Context, task mq.Result) mq.Result {
|
||||
if con, ok := d.nodes[task.Queue]; ok {
|
||||
return con.ProcessTask(ctx, mq.Task{
|
||||
ID: task.MessageID,
|
||||
Payload: task.Payload,
|
||||
CurrentQueue: task.Queue,
|
||||
})
|
||||
}
|
||||
return mq.Result{Error: fmt.Errorf("no consumer to process %s", task.Queue)}
|
||||
}
|
||||
|
||||
func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result {
|
||||
if task.MessageID == "" {
|
||||
task.MessageID = mq.NewID()
|
||||
}
|
||||
if task.Queue == "" {
|
||||
task.Queue = d.FirstNode
|
||||
}
|
||||
result := d.processNode(ctx, task)
|
||||
if result.Error != nil {
|
||||
return result
|
||||
}
|
||||
for _, target := range d.loopEdges[task.Queue] {
|
||||
var items, results []json.RawMessage
|
||||
if err := json.Unmarshal(result.Payload, &items); err != nil {
|
||||
return mq.Result{Error: err}
|
||||
}
|
||||
for _, item := range items {
|
||||
result = d.sendSync(ctx, mq.Result{
|
||||
Command: result.Command,
|
||||
Payload: item,
|
||||
Queue: target,
|
||||
MessageID: result.MessageID,
|
||||
})
|
||||
if result.Error != nil {
|
||||
return result
|
||||
}
|
||||
results = append(results, result.Payload)
|
||||
}
|
||||
bt, err := json.Marshal(results)
|
||||
if err != nil {
|
||||
return mq.Result{Error: err}
|
||||
}
|
||||
result.Payload = bt
|
||||
}
|
||||
if target, ok := d.edges[task.Queue]; ok {
|
||||
result = d.sendSync(ctx, mq.Result{
|
||||
Command: result.Command,
|
||||
Payload: result.Payload,
|
||||
Queue: target,
|
||||
MessageID: result.MessageID,
|
||||
})
|
||||
if result.Error != nil {
|
||||
return result
|
||||
}
|
||||
}
|
||||
return result
|
||||
return d.server.Publish(ctx, task, queue)
|
||||
}
|
||||
|
||||
func (d *DAG) getCompletedResults(task mq.Result, ok bool, triggeredNode string) ([]byte, bool, bool) {
|
||||
@@ -264,9 +199,12 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue})
|
||||
for _, loopNode := range loopNodes {
|
||||
for _, item := range items {
|
||||
rs := d.PublishTask(ctx, item, loopNode, task.MessageID)
|
||||
if rs.Error != nil {
|
||||
return rs
|
||||
ctx = mq.SetHeaders(ctx, map[string]string{
|
||||
consts.QueueKey: loopNode,
|
||||
})
|
||||
err := d.PublishTask(ctx, item, task.MessageID)
|
||||
if err != nil {
|
||||
return mq.Result{Error: err}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -284,15 +222,14 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
|
||||
totalItems: 1,
|
||||
},
|
||||
}
|
||||
rs := d.PublishTask(ctx, payload, edge, task.MessageID)
|
||||
if rs.Error != nil {
|
||||
return rs
|
||||
err := d.PublishTask(ctx, payload, edge, task.MessageID)
|
||||
if err != nil {
|
||||
return mq.Result{Error: err}
|
||||
}
|
||||
} else if completed {
|
||||
d.mu.Lock()
|
||||
if resultCh, ok := d.taskChMap[task.MessageID]; ok {
|
||||
resultCh <- mq.Result{
|
||||
Command: "complete",
|
||||
Payload: payload,
|
||||
Queue: task.Queue,
|
||||
MessageID: task.MessageID,
|
||||
|
@@ -2,9 +2,9 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/oarkflow/mq"
|
||||
|
||||
"github.com/oarkflow/mq/examples/tasks"
|
||||
mq "github.com/oarkflow/mq/v2"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
@@ -2,19 +2,15 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
"github.com/oarkflow/mq/dag"
|
||||
"github.com/oarkflow/mq/examples/tasks"
|
||||
"time"
|
||||
)
|
||||
|
||||
var d *dag.DAG
|
||||
|
||||
func main() {
|
||||
d = dag.New(mq.WithTLS(true, "server.crt", "server.key"), mq.WithCAPath("ca.crt"))
|
||||
d = dag.New()
|
||||
d.AddNode("queue1", tasks.Node1)
|
||||
d.AddNode("queue2", tasks.Node2)
|
||||
d.AddNode("queue3", tasks.Node3)
|
||||
@@ -24,45 +20,14 @@ func main() {
|
||||
d.AddLoop("queue2", "queue3")
|
||||
d.AddEdge("queue2", "queue4")
|
||||
d.Prepare()
|
||||
http.HandleFunc("POST /publish", requestHandler("publish"))
|
||||
http.HandleFunc("POST /request", requestHandler("request"))
|
||||
err := d.Start(context.TODO(), ":8083")
|
||||
go func() {
|
||||
d.Start(context.Background(), ":8081")
|
||||
}()
|
||||
time.Sleep(5 * time.Second)
|
||||
err := d.PublishTask(context.Background(), []byte(`{"tast": 123}`))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
var payload []byte
|
||||
if r.Body != nil {
|
||||
defer r.Body.Close()
|
||||
var err error
|
||||
payload, err = io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
http.Error(w, "Empty request body", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
var rs mq.Result
|
||||
if requestType == "request" {
|
||||
rs = d.Request(context.Background(), payload)
|
||||
} else {
|
||||
rs = d.Send(context.Background(), payload)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
result := map[string]any{
|
||||
"message_id": rs.MessageID,
|
||||
"payload": string(rs.Payload),
|
||||
"error": rs.Error,
|
||||
}
|
||||
json.NewEncoder(w).Encode(result)
|
||||
}
|
||||
time.Sleep(10 * time.Second)
|
||||
}
|
||||
|
@@ -3,17 +3,16 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
mq2 "github.com/oarkflow/mq"
|
||||
"time"
|
||||
|
||||
mq "github.com/oarkflow/mq/v2"
|
||||
)
|
||||
|
||||
func main() {
|
||||
payload := []byte(`{"message":"Message Publisher \n Task"}`)
|
||||
task := mq.Task{
|
||||
task := mq2.Task{
|
||||
Payload: payload,
|
||||
}
|
||||
publisher := mq.NewPublisher("publish-1")
|
||||
publisher := mq2.NewPublisher("publish-1")
|
||||
// publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"))
|
||||
err := publisher.Publish(context.Background(), task, "queue1")
|
||||
if err != nil {
|
||||
@@ -21,7 +20,7 @@ func main() {
|
||||
}
|
||||
fmt.Println("Async task published successfully")
|
||||
payload = []byte(`{"message":"Fire-and-Forget \n Task"}`)
|
||||
task = mq.Task{
|
||||
task = mq2.Task{
|
||||
Payload: payload,
|
||||
}
|
||||
for i := 0; i < 100; i++ {
|
||||
|
@@ -2,13 +2,13 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
mq2 "github.com/oarkflow/mq"
|
||||
|
||||
"github.com/oarkflow/mq/examples/tasks"
|
||||
mq "github.com/oarkflow/mq/v2"
|
||||
)
|
||||
|
||||
func main() {
|
||||
b := mq.NewBroker(mq.WithCallback(tasks.Callback))
|
||||
b := mq2.NewBroker(mq2.WithCallback(tasks.Callback))
|
||||
// b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert"))
|
||||
b.NewQueue("queue1")
|
||||
b.NewQueue("queue2")
|
||||
|
@@ -4,42 +4,41 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
mq "github.com/oarkflow/mq/v2"
|
||||
mq2 "github.com/oarkflow/mq"
|
||||
)
|
||||
|
||||
func Node1(ctx context.Context, task mq.Task) mq.Result {
|
||||
func Node1(ctx context.Context, task mq2.Task) mq2.Result {
|
||||
fmt.Println("Processing queue1", task.ID)
|
||||
return mq.Result{Payload: task.Payload, MessageID: task.ID}
|
||||
return mq2.Result{Payload: task.Payload, MessageID: task.ID}
|
||||
}
|
||||
|
||||
func Node2(ctx context.Context, task mq.Task) mq.Result {
|
||||
return mq.Result{Payload: task.Payload, MessageID: task.ID}
|
||||
func Node2(ctx context.Context, task mq2.Task) mq2.Result {
|
||||
return mq2.Result{Payload: task.Payload, MessageID: task.ID}
|
||||
}
|
||||
|
||||
func Node3(ctx context.Context, task mq.Task) mq.Result {
|
||||
func Node3(ctx context.Context, task mq2.Task) mq2.Result {
|
||||
var data map[string]any
|
||||
err := json.Unmarshal(task.Payload, &data)
|
||||
if err != nil {
|
||||
return mq.Result{Error: err}
|
||||
return mq2.Result{Error: err}
|
||||
}
|
||||
data["salary"] = fmt.Sprintf("12000%v", data["user_id"])
|
||||
bt, _ := json.Marshal(data)
|
||||
return mq.Result{Payload: bt, MessageID: task.ID}
|
||||
return mq2.Result{Payload: bt, MessageID: task.ID}
|
||||
}
|
||||
|
||||
func Node4(ctx context.Context, task mq.Task) mq.Result {
|
||||
func Node4(ctx context.Context, task mq2.Task) mq2.Result {
|
||||
var data []map[string]any
|
||||
err := json.Unmarshal(task.Payload, &data)
|
||||
if err != nil {
|
||||
return mq.Result{Error: err}
|
||||
return mq2.Result{Error: err}
|
||||
}
|
||||
payload := map[string]any{"storage": data}
|
||||
bt, _ := json.Marshal(payload)
|
||||
return mq.Result{Payload: bt, MessageID: task.ID}
|
||||
return mq2.Result{Payload: bt, MessageID: task.ID}
|
||||
}
|
||||
|
||||
func Callback(ctx context.Context, task mq.Result) mq.Result {
|
||||
func Callback(ctx context.Context, task mq2.Result) mq2.Result {
|
||||
fmt.Println("Received task", task.MessageID, "Payload", string(task.Payload), task.Error, task.Queue)
|
||||
return mq.Result{}
|
||||
return mq2.Result{}
|
||||
}
|
||||
|
69
options.go
69
options.go
@@ -2,9 +2,18 @@ package mq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Result struct {
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
Queue string `json:"queue"`
|
||||
MessageID string `json:"message_id"`
|
||||
Error error `json:"error,omitempty"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type TLSConfig struct {
|
||||
UseTLS bool
|
||||
CertPath string
|
||||
@@ -15,15 +24,16 @@ type TLSConfig struct {
|
||||
type Options struct {
|
||||
syncMode bool
|
||||
brokerAddr string
|
||||
messageHandler MessageHandler
|
||||
closeHandler CloseHandler
|
||||
errorHandler ErrorHandler
|
||||
callback []func(context.Context, Result) Result
|
||||
maxRetries int
|
||||
initialDelay time.Duration
|
||||
maxBackoff time.Duration
|
||||
jitterPercent float64
|
||||
tlsConfig TLSConfig
|
||||
aesKey json.RawMessage
|
||||
hmacKey json.RawMessage
|
||||
enableEncryption bool
|
||||
queueSize int
|
||||
}
|
||||
|
||||
func defaultOptions() Options {
|
||||
@@ -34,27 +44,29 @@ func defaultOptions() Options {
|
||||
initialDelay: 2 * time.Second,
|
||||
maxBackoff: 20 * time.Second,
|
||||
jitterPercent: 0.5,
|
||||
queueSize: 100,
|
||||
}
|
||||
}
|
||||
|
||||
func defaultHandlers(options Options, onMessage MessageHandler, onClose CloseHandler, onError ErrorHandler) Options {
|
||||
if options.messageHandler == nil {
|
||||
options.messageHandler = onMessage
|
||||
}
|
||||
|
||||
if options.closeHandler == nil {
|
||||
options.closeHandler = onClose
|
||||
}
|
||||
|
||||
if options.errorHandler == nil {
|
||||
options.errorHandler = onError
|
||||
}
|
||||
return options
|
||||
}
|
||||
|
||||
// Option defines a function type for setting options.
|
||||
type Option func(*Options)
|
||||
|
||||
func setupOptions(opts ...Option) Options {
|
||||
options := defaultOptions()
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return options
|
||||
}
|
||||
|
||||
func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option {
|
||||
return func(opts *Options) {
|
||||
opts.aesKey = aesKey
|
||||
opts.hmacKey = hmacKey
|
||||
opts.enableEncryption = enableEncryption
|
||||
}
|
||||
}
|
||||
|
||||
// WithBrokerURL -
|
||||
func WithBrokerURL(url string) Option {
|
||||
return func(opts *Options) {
|
||||
@@ -119,24 +131,3 @@ func WithJitterPercent(val float64) Option {
|
||||
opts.jitterPercent = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithMessageHandler sets a custom MessageHandler.
|
||||
func WithMessageHandler(handler MessageHandler) Option {
|
||||
return func(opts *Options) {
|
||||
opts.messageHandler = handler
|
||||
}
|
||||
}
|
||||
|
||||
// WithErrorHandler sets a custom ErrorHandler.
|
||||
func WithErrorHandler(handler ErrorHandler) Option {
|
||||
return func(opts *Options) {
|
||||
opts.errorHandler = handler
|
||||
}
|
||||
}
|
||||
|
||||
// WithCloseHandler sets a custom CloseHandler.
|
||||
func WithCloseHandler(handler CloseHandler) Option {
|
||||
return func(opts *Options) {
|
||||
opts.closeHandler = handler
|
||||
}
|
||||
}
|
||||
|
100
publisher.go
100
publisher.go
@@ -4,9 +4,13 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/codec"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
"github.com/oarkflow/mq/jsonparser"
|
||||
)
|
||||
|
||||
type Publisher struct {
|
||||
@@ -15,31 +19,59 @@ type Publisher struct {
|
||||
}
|
||||
|
||||
func NewPublisher(id string, opts ...Option) *Publisher {
|
||||
options := defaultOptions()
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
b := &Publisher{id: id}
|
||||
b.opts = defaultHandlers(options, nil, b.onClose, b.onError)
|
||||
return b
|
||||
options := setupOptions(opts...)
|
||||
return &Publisher{id: id, opts: options}
|
||||
}
|
||||
|
||||
func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error {
|
||||
ctx = SetHeaders(ctx, map[string]string{
|
||||
headers := WithHeaders(ctx, map[string]string{
|
||||
consts.PublisherKey: p.id,
|
||||
consts.ContentType: consts.TypeJson,
|
||||
})
|
||||
cmd := Command{
|
||||
ID: NewID(),
|
||||
Command: command,
|
||||
Queue: queue,
|
||||
MessageID: task.ID,
|
||||
Payload: task.Payload,
|
||||
if task.ID == "" {
|
||||
task.ID = NewID()
|
||||
}
|
||||
return Write(ctx, conn, cmd)
|
||||
task.CreatedAt = time.Now()
|
||||
payload, err := json.Marshal(task)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg := codec.NewMessage(command, payload, queue, headers)
|
||||
if err := codec.SendMessage(conn, msg, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return p.waitForAck(conn)
|
||||
}
|
||||
|
||||
func (p *Publisher) Publish(ctx context.Context, queue string, task Task) error {
|
||||
func (p *Publisher) waitForAck(conn net.Conn) error {
|
||||
msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Command == consts.PUBLISH_ACK {
|
||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||
log.Printf("PUBLISHER - PUBLISH_ACK ~> from %s on %s for Task %s", p.id, msg.Queue, taskID)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
|
||||
}
|
||||
|
||||
func (p *Publisher) waitForResponse(conn net.Conn) Result {
|
||||
msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption)
|
||||
if err != nil {
|
||||
return Result{Error: err}
|
||||
}
|
||||
if msg.Command == consts.RESPONSE {
|
||||
var result Result
|
||||
err = json.Unmarshal(msg.Payload, &result)
|
||||
return result
|
||||
}
|
||||
err = fmt.Errorf("expected RESPONSE, got: %v", msg.Command)
|
||||
return Result{Error: err}
|
||||
}
|
||||
|
||||
func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error {
|
||||
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to broker: %w", err)
|
||||
@@ -57,30 +89,22 @@ func (p *Publisher) onError(ctx context.Context, conn net.Conn, err error) {
|
||||
fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr())
|
||||
}
|
||||
|
||||
func (p *Publisher) Request(ctx context.Context, queue string, task Task) (Result, error) {
|
||||
func (p *Publisher) Request(ctx context.Context, queue string, task Task) Result {
|
||||
ctx = SetHeaders(ctx, map[string]string{
|
||||
consts.AwaitResponseKey: "true",
|
||||
})
|
||||
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
||||
if err != nil {
|
||||
return Result{Error: err}, fmt.Errorf("failed to connect to broker: %w", err)
|
||||
err = fmt.Errorf("failed to connect to broker: %w", err)
|
||||
return Result{Error: err}
|
||||
}
|
||||
defer conn.Close()
|
||||
var result Result
|
||||
err = p.send(ctx, queue, task, conn, consts.REQUEST)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
if p.opts.messageHandler == nil {
|
||||
p.opts.messageHandler = func(ctx context.Context, conn net.Conn, message []byte) error {
|
||||
err := json.Unmarshal(message, &result)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.Close()
|
||||
}
|
||||
}
|
||||
ReadFromConn(ctx, conn, Handlers{
|
||||
MessageHandler: p.opts.messageHandler,
|
||||
CloseHandler: p.opts.closeHandler,
|
||||
ErrorHandler: p.opts.errorHandler,
|
||||
})
|
||||
return result, nil
|
||||
err = p.send(ctx, queue, task, conn, consts.PUBLISH)
|
||||
resultCh := make(chan Result)
|
||||
go func() {
|
||||
defer close(resultCh)
|
||||
resultCh <- p.waitForResponse(conn)
|
||||
}()
|
||||
finalResult := <-resultCh
|
||||
return finalResult
|
||||
}
|
||||
|
@@ -1,4 +1,4 @@
|
||||
package v2
|
||||
package mq
|
||||
|
||||
import (
|
||||
"github.com/oarkflow/xsync"
|
@@ -1,4 +1,4 @@
|
||||
package v2
|
||||
package mq
|
||||
|
||||
import (
|
||||
"context"
|
323
v2/broker.go
323
v2/broker.go
@@ -1,323 +0,0 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/xsync"
|
||||
|
||||
"github.com/oarkflow/mq/codec"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
"github.com/oarkflow/mq/jsonparser"
|
||||
"github.com/oarkflow/mq/utils"
|
||||
)
|
||||
|
||||
type QueuedTask struct {
|
||||
Message *codec.Message
|
||||
RetryCount int
|
||||
}
|
||||
|
||||
type consumer struct {
|
||||
id string
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
type publisher struct {
|
||||
id string
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
type Broker struct {
|
||||
queues xsync.IMap[string, *Queue]
|
||||
consumers xsync.IMap[string, *consumer]
|
||||
publishers xsync.IMap[string, *publisher]
|
||||
opts Options
|
||||
}
|
||||
|
||||
func NewBroker(opts ...Option) *Broker {
|
||||
options := setupOptions(opts...)
|
||||
return &Broker{
|
||||
queues: xsync.NewMap[string, *Queue](),
|
||||
publishers: xsync.NewMap[string, *publisher](),
|
||||
consumers: xsync.NewMap[string, *consumer](),
|
||||
opts: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) OnClose(ctx context.Context, _ net.Conn) error {
|
||||
consumerID, ok := GetConsumerID(ctx)
|
||||
if ok && consumerID != "" {
|
||||
if con, exists := b.consumers.Get(consumerID); exists {
|
||||
con.conn.Close()
|
||||
b.consumers.Del(consumerID)
|
||||
}
|
||||
b.queues.ForEach(func(_ string, queue *Queue) bool {
|
||||
queue.consumers.Del(consumerID)
|
||||
return true
|
||||
})
|
||||
}
|
||||
publisherID, ok := GetPublisherID(ctx)
|
||||
if ok && publisherID != "" {
|
||||
if con, exists := b.publishers.Get(publisherID); exists {
|
||||
con.conn.Close()
|
||||
b.publishers.Del(publisherID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) {
|
||||
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
|
||||
}
|
||||
|
||||
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||||
switch msg.Command {
|
||||
case consts.PUBLISH:
|
||||
b.PublishHandler(ctx, conn, msg)
|
||||
case consts.SUBSCRIBE:
|
||||
b.SubscribeHandler(ctx, conn, msg)
|
||||
case consts.MESSAGE_RESPONSE:
|
||||
b.MessageResponseHandler(ctx, msg)
|
||||
case consts.MESSAGE_ACK:
|
||||
b.MessageAck(ctx, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) {
|
||||
consumerID, _ := GetConsumerID(ctx)
|
||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||
log.Printf("BROKER - MESSAGE_ACK ~> %s on %s for Task %s", consumerID, msg.Queue, taskID)
|
||||
}
|
||||
|
||||
func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) {
|
||||
msg.Command = consts.RESPONSE
|
||||
headers, ok := GetHeaders(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
b.HandleCallback(ctx, msg)
|
||||
awaitResponse, ok := headers[consts.AwaitResponseKey]
|
||||
if !(ok && awaitResponse == "true") {
|
||||
return
|
||||
}
|
||||
publisherID, exists := headers[consts.PublisherKey]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
con, ok := b.publishers.Get(publisherID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
err := b.send(con.conn, msg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) Publish(ctx context.Context, task Task, queue string) error {
|
||||
headers, _ := GetHeaders(ctx)
|
||||
msg := codec.NewMessage(consts.PUBLISH, task.Payload, queue, headers)
|
||||
b.broadcastToConsumers(ctx, msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
||||
pub := b.addPublisher(ctx, msg.Queue, conn)
|
||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
|
||||
|
||||
ack := codec.NewMessage(consts.PUBLISH_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
|
||||
if err := b.send(conn, ack); err != nil {
|
||||
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
|
||||
}
|
||||
b.broadcastToConsumers(ctx, msg)
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
b.publishers.Del(pub.id)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
||||
consumerID := b.addConsumer(ctx, msg.Queue, conn)
|
||||
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
|
||||
if err := b.send(conn, ack); err != nil {
|
||||
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
|
||||
}
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
b.removeConsumer(msg.Queue, consumerID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (b *Broker) Start(ctx context.Context) error {
|
||||
var listener net.Listener
|
||||
var err error
|
||||
|
||||
if b.opts.tlsConfig.UseTLS {
|
||||
cert, err := tls.LoadX509KeyPair(b.opts.tlsConfig.CertPath, b.opts.tlsConfig.KeyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load TLS certificates: %v", err)
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
listener, err = tls.Listen("tcp", b.opts.brokerAddr, tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start TLS listener: %v", err)
|
||||
}
|
||||
log.Println("BROKER - RUNNING_TLS ~> started on", b.opts.brokerAddr)
|
||||
} else {
|
||||
listener, err = net.Listen("tcp", b.opts.brokerAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start TCP listener: %v", err)
|
||||
}
|
||||
log.Println("BROKER - RUNNING ~> started on", b.opts.brokerAddr)
|
||||
}
|
||||
defer listener.Close()
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
b.OnError(ctx, conn, err)
|
||||
continue
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
for {
|
||||
err := b.readMessage(ctx, c)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) send(conn net.Conn, msg *codec.Message) error {
|
||||
return codec.SendMessage(conn, msg, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
|
||||
}
|
||||
|
||||
func (b *Broker) receive(c net.Conn) (*codec.Message, error) {
|
||||
return codec.ReadMessage(c, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
|
||||
}
|
||||
|
||||
func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) {
|
||||
if queue, ok := b.queues.Get(msg.Queue); ok {
|
||||
task := &QueuedTask{Message: msg, RetryCount: 0}
|
||||
queue.tasks <- task
|
||||
log.Printf("Task enqueued for queue %s", msg.Queue)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) waitForConsumerAck(conn net.Conn) error {
|
||||
msg, err := b.receive(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Command == consts.MESSAGE_ACK {
|
||||
log.Println("Received CONSUMER_ACK: Subscribed successfully")
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("expected CONSUMER_ACK, got: %v", msg.Command)
|
||||
}
|
||||
|
||||
func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) *publisher {
|
||||
publisherID, ok := GetPublisherID(ctx)
|
||||
_, ok = b.queues.Get(queueName)
|
||||
if !ok {
|
||||
b.NewQueue(queueName)
|
||||
}
|
||||
con := &publisher{id: publisherID, conn: conn}
|
||||
b.publishers.Set(publisherID, con)
|
||||
return con
|
||||
}
|
||||
|
||||
func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string {
|
||||
consumerID, ok := GetConsumerID(ctx)
|
||||
q, ok := b.queues.Get(queueName)
|
||||
if !ok {
|
||||
q = b.NewQueue(queueName)
|
||||
}
|
||||
con := &consumer{id: consumerID, conn: conn}
|
||||
b.consumers.Set(consumerID, con)
|
||||
q.consumers.Set(consumerID, con)
|
||||
log.Printf("BROKER - SUBSCRIBE ~> %s on %s", consumerID, queueName)
|
||||
return consumerID
|
||||
}
|
||||
|
||||
func (b *Broker) removeConsumer(queueName, consumerID string) {
|
||||
if queue, ok := b.queues.Get(queueName); ok {
|
||||
con, ok := queue.consumers.Get(consumerID)
|
||||
if ok {
|
||||
con.conn.Close()
|
||||
queue.consumers.Del(consumerID)
|
||||
}
|
||||
b.queues.Del(queueName)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
|
||||
msg, err := b.receive(c)
|
||||
if err == nil {
|
||||
ctx = SetHeaders(ctx, msg.Headers)
|
||||
b.OnMessage(ctx, msg, c)
|
||||
return nil
|
||||
}
|
||||
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
|
||||
b.OnClose(ctx, c)
|
||||
return err
|
||||
}
|
||||
b.OnError(ctx, c, err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *Broker) dispatchWorker(queue *Queue) {
|
||||
delay := b.opts.initialDelay
|
||||
for task := range queue.tasks {
|
||||
success := false
|
||||
for !success && task.RetryCount <= b.opts.maxRetries {
|
||||
if b.dispatchTaskToConsumer(queue, task) {
|
||||
success = true
|
||||
} else {
|
||||
task.RetryCount++
|
||||
delay = b.backoffRetry(queue, task, delay)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool {
|
||||
var consumerFound bool
|
||||
queue.consumers.ForEach(func(_ string, con *consumer) bool {
|
||||
if err := b.send(con.conn, task.Message); err == nil {
|
||||
consumerFound = true
|
||||
log.Printf("Task dispatched to consumer %s on queue %s", con.id, queue.name)
|
||||
return false // break the loop once a consumer is found
|
||||
}
|
||||
return true
|
||||
})
|
||||
if !consumerFound {
|
||||
log.Printf("No available consumers for queue %s, retrying...", queue.name)
|
||||
}
|
||||
return consumerFound
|
||||
}
|
||||
|
||||
func (b *Broker) backoffRetry(queue *Queue, task *QueuedTask, delay time.Duration) time.Duration {
|
||||
backoffDuration := utils.CalculateJitter(delay, b.opts.jitterPercent)
|
||||
log.Printf("Backing off for %v before retrying task for queue %s", backoffDuration, task.Message.Queue)
|
||||
time.Sleep(backoffDuration)
|
||||
queue.tasks <- task
|
||||
delay *= 2
|
||||
if delay > b.opts.maxBackoff {
|
||||
delay = b.opts.maxBackoff
|
||||
}
|
||||
return delay
|
||||
}
|
197
v2/consumer.go
197
v2/consumer.go
@@ -1,197 +0,0 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/codec"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
"github.com/oarkflow/mq/jsonparser"
|
||||
"github.com/oarkflow/mq/utils"
|
||||
)
|
||||
|
||||
// Consumer structure to hold consumer-specific configurations and state.
|
||||
type Consumer struct {
|
||||
id string
|
||||
handlers map[string]Handler
|
||||
conn net.Conn
|
||||
queues []string
|
||||
opts Options
|
||||
}
|
||||
|
||||
// NewConsumer initializes a new consumer with the provided options.
|
||||
func NewConsumer(id string, opts ...Option) *Consumer {
|
||||
options := setupOptions(opts...)
|
||||
return &Consumer{
|
||||
handlers: make(map[string]Handler),
|
||||
id: id,
|
||||
opts: options,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Consumer) send(conn net.Conn, msg *codec.Message) error {
|
||||
return codec.SendMessage(conn, msg, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
|
||||
}
|
||||
|
||||
func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) {
|
||||
return codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
|
||||
}
|
||||
|
||||
// Close closes the consumer's connection.
|
||||
func (c *Consumer) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
// Subscribe to a specific queue.
|
||||
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
|
||||
headers := WithHeaders(ctx, map[string]string{
|
||||
consts.ConsumerKey: c.id,
|
||||
consts.ContentType: consts.TypeJson,
|
||||
})
|
||||
msg := codec.NewMessage(consts.SUBSCRIBE, nil, queue, headers)
|
||||
if err := c.send(c.conn, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.waitForAck(c.conn)
|
||||
}
|
||||
|
||||
func (c *Consumer) OnClose(ctx context.Context, _ net.Conn) error {
|
||||
fmt.Println("Consumer closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) {
|
||||
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
|
||||
}
|
||||
|
||||
func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||||
headers := WithHeaders(ctx, map[string]string{
|
||||
consts.ConsumerKey: c.id,
|
||||
consts.ContentType: consts.TypeJson,
|
||||
})
|
||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||
reply := codec.NewMessage(consts.MESSAGE_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
|
||||
if err := c.send(conn, reply); err != nil {
|
||||
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
|
||||
}
|
||||
var task Task
|
||||
err := json.Unmarshal(msg.Payload, &task)
|
||||
if err != nil {
|
||||
log.Println("Error unmarshalling message:", err)
|
||||
return
|
||||
}
|
||||
ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue})
|
||||
result := c.ProcessTask(ctx, task)
|
||||
result.MessageID = task.ID
|
||||
result.Queue = msg.Queue
|
||||
if result.Error != nil {
|
||||
result.Status = "FAILED"
|
||||
} else {
|
||||
result.Status = "SUCCESS"
|
||||
}
|
||||
bt, _ := json.Marshal(result)
|
||||
reply = codec.NewMessage(consts.MESSAGE_RESPONSE, bt, msg.Queue, headers)
|
||||
if err := c.send(conn, reply); err != nil {
|
||||
fmt.Printf("failed to send MESSAGE_RESPONSE for queue %s: %v", msg.Queue, err)
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessTask handles a received task message and invokes the appropriate handler.
|
||||
func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result {
|
||||
queue, _ := GetQueue(ctx)
|
||||
handler, exists := c.handlers[queue]
|
||||
if !exists {
|
||||
return Result{Error: errors.New("No handler for queue " + queue)}
|
||||
}
|
||||
return handler(ctx, msg)
|
||||
}
|
||||
|
||||
// AttemptConnect tries to establish a connection to the server, with TLS or without, based on the configuration.
|
||||
func (c *Consumer) AttemptConnect() error {
|
||||
var err error
|
||||
delay := c.opts.initialDelay
|
||||
for i := 0; i < c.opts.maxRetries; i++ {
|
||||
conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig)
|
||||
if err == nil {
|
||||
c.conn = conn
|
||||
return nil
|
||||
}
|
||||
sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent)
|
||||
log.Printf("CONSUMER - SUBSCRIBE ~> Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
|
||||
time.Sleep(sleepDuration)
|
||||
delay *= 2
|
||||
if delay > c.opts.maxBackoff {
|
||||
delay = c.opts.maxBackoff
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err)
|
||||
}
|
||||
|
||||
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
|
||||
msg, err := c.receive(conn)
|
||||
if err == nil {
|
||||
ctx = SetHeaders(ctx, msg.Headers)
|
||||
c.OnMessage(ctx, msg, conn)
|
||||
return nil
|
||||
}
|
||||
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
|
||||
c.OnClose(ctx, conn)
|
||||
return err
|
||||
}
|
||||
c.OnError(ctx, conn, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Consume starts the consumer to consume tasks from the queues.
|
||||
func (c *Consumer) Consume(ctx context.Context) error {
|
||||
err := c.AttemptConnect()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, q := range c.queues {
|
||||
if err := c.subscribe(ctx, q); err != nil {
|
||||
return fmt.Errorf("failed to connect to server for queue %s: %v", q, err)
|
||||
}
|
||||
}
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
if err := c.readMessage(ctx, c.conn); err != nil {
|
||||
log.Println("Error reading message:", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Consumer) waitForAck(conn net.Conn) error {
|
||||
msg, err := c.receive(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Command == consts.SUBSCRIBE_ACK {
|
||||
log.Printf("CONSUMER - SUBSCRIBE_ACK ~> %s on %s", c.id, msg.Queue)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command)
|
||||
}
|
||||
|
||||
// RegisterHandler registers a handler for a queue.
|
||||
func (c *Consumer) RegisterHandler(queue string, handler Handler) {
|
||||
c.queues = append(c.queues, queue)
|
||||
c.handlers[queue] = handler
|
||||
}
|
158
v2/ctx.go
158
v2/ctx.go
@@ -1,158 +0,0 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/xid"
|
||||
|
||||
"github.com/oarkflow/mq/consts"
|
||||
)
|
||||
|
||||
type Task struct {
|
||||
ID string `json:"id"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ProcessedAt time.Time `json:"processed_at"`
|
||||
Status string `json:"status"`
|
||||
Error error `json:"error"`
|
||||
}
|
||||
|
||||
type Handler func(context.Context, Task) Result
|
||||
|
||||
func IsClosed(conn net.Conn) bool {
|
||||
_, err := conn.Read(make([]byte, 1))
|
||||
if err != nil {
|
||||
if err == net.ErrClosed {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func SetHeaders(ctx context.Context, headers map[string]string) context.Context {
|
||||
hd, ok := GetHeaders(ctx)
|
||||
if !ok {
|
||||
hd = make(map[string]string)
|
||||
}
|
||||
for key, val := range headers {
|
||||
hd[key] = val
|
||||
}
|
||||
return context.WithValue(ctx, consts.HeaderKey, hd)
|
||||
}
|
||||
|
||||
func WithHeaders(ctx context.Context, headers map[string]string) map[string]string {
|
||||
hd, ok := GetHeaders(ctx)
|
||||
if !ok {
|
||||
hd = make(map[string]string)
|
||||
}
|
||||
for key, val := range headers {
|
||||
hd[key] = val
|
||||
}
|
||||
return hd
|
||||
}
|
||||
|
||||
func GetHeaders(ctx context.Context) (map[string]string, bool) {
|
||||
headers, ok := ctx.Value(consts.HeaderKey).(map[string]string)
|
||||
return headers, ok
|
||||
}
|
||||
|
||||
func GetHeader(ctx context.Context, key string) (string, bool) {
|
||||
headers, ok := ctx.Value(consts.HeaderKey).(map[string]string)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
val, ok := headers[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
func GetContentType(ctx context.Context) (string, bool) {
|
||||
headers, ok := GetHeaders(ctx)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
contentType, ok := headers[consts.ContentType]
|
||||
return contentType, ok
|
||||
}
|
||||
|
||||
func GetQueue(ctx context.Context) (string, bool) {
|
||||
headers, ok := GetHeaders(ctx)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
contentType, ok := headers[consts.QueueKey]
|
||||
return contentType, ok
|
||||
}
|
||||
|
||||
func GetConsumerID(ctx context.Context) (string, bool) {
|
||||
headers, ok := GetHeaders(ctx)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
contentType, ok := headers[consts.ConsumerKey]
|
||||
return contentType, ok
|
||||
}
|
||||
|
||||
func GetTriggerNode(ctx context.Context) (string, bool) {
|
||||
headers, ok := GetHeaders(ctx)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
contentType, ok := headers[consts.TriggerNode]
|
||||
return contentType, ok
|
||||
}
|
||||
|
||||
func GetPublisherID(ctx context.Context) (string, bool) {
|
||||
headers, ok := GetHeaders(ctx)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
contentType, ok := headers[consts.PublisherKey]
|
||||
return contentType, ok
|
||||
}
|
||||
|
||||
func NewID() string {
|
||||
return xid.New().String()
|
||||
}
|
||||
|
||||
func createTLSConnection(addr, certPath, keyPath string, caPath ...string) (net.Conn, error) {
|
||||
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load client cert/key: %w", err)
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
if len(caPath) > 0 && caPath[0] != "" {
|
||||
caCert, err := os.ReadFile(caPath[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load CA cert: %w", err)
|
||||
}
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AppendCertsFromPEM(caCert)
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
tlsConfig.ClientCAs = caCertPool
|
||||
}
|
||||
conn, err := tls.Dial("tcp", addr, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial TLS connection: %w", err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func GetConnection(addr string, config TLSConfig) (net.Conn, error) {
|
||||
if config.UseTLS {
|
||||
return createTLSConnection(addr, config.CertPath, config.KeyPath, config.CAPath)
|
||||
} else {
|
||||
return net.Dial("tcp", addr)
|
||||
}
|
||||
}
|
133
v2/options.go
133
v2/options.go
@@ -1,133 +0,0 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Result struct {
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
Queue string `json:"queue"`
|
||||
MessageID string `json:"message_id"`
|
||||
Error error `json:"error,omitempty"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type TLSConfig struct {
|
||||
UseTLS bool
|
||||
CertPath string
|
||||
KeyPath string
|
||||
CAPath string
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
syncMode bool
|
||||
brokerAddr string
|
||||
callback []func(context.Context, Result) Result
|
||||
maxRetries int
|
||||
initialDelay time.Duration
|
||||
maxBackoff time.Duration
|
||||
jitterPercent float64
|
||||
tlsConfig TLSConfig
|
||||
aesKey json.RawMessage
|
||||
hmacKey json.RawMessage
|
||||
enableEncryption bool
|
||||
queueSize int
|
||||
}
|
||||
|
||||
func defaultOptions() Options {
|
||||
return Options{
|
||||
syncMode: false,
|
||||
brokerAddr: ":8080",
|
||||
maxRetries: 5,
|
||||
initialDelay: 2 * time.Second,
|
||||
maxBackoff: 20 * time.Second,
|
||||
jitterPercent: 0.5,
|
||||
queueSize: 100,
|
||||
}
|
||||
}
|
||||
|
||||
// Option defines a function type for setting options.
|
||||
type Option func(*Options)
|
||||
|
||||
func setupOptions(opts ...Option) Options {
|
||||
options := defaultOptions()
|
||||
for _, opt := range opts {
|
||||
opt(&options)
|
||||
}
|
||||
return options
|
||||
}
|
||||
|
||||
func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option {
|
||||
return func(opts *Options) {
|
||||
opts.aesKey = aesKey
|
||||
opts.hmacKey = hmacKey
|
||||
opts.enableEncryption = enableEncryption
|
||||
}
|
||||
}
|
||||
|
||||
// WithBrokerURL -
|
||||
func WithBrokerURL(url string) Option {
|
||||
return func(opts *Options) {
|
||||
opts.brokerAddr = url
|
||||
}
|
||||
}
|
||||
|
||||
// WithTLS - Option to enable/disable TLS
|
||||
func WithTLS(enableTLS bool, certPath, keyPath string) Option {
|
||||
return func(o *Options) {
|
||||
o.tlsConfig.UseTLS = enableTLS
|
||||
o.tlsConfig.CertPath = certPath
|
||||
o.tlsConfig.KeyPath = keyPath
|
||||
}
|
||||
}
|
||||
|
||||
// WithCAPath - Option to enable/disable TLS
|
||||
func WithCAPath(caPath string) Option {
|
||||
return func(o *Options) {
|
||||
o.tlsConfig.CAPath = caPath
|
||||
}
|
||||
}
|
||||
|
||||
// WithSyncMode -
|
||||
func WithSyncMode(mode bool) Option {
|
||||
return func(opts *Options) {
|
||||
opts.syncMode = mode
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxRetries -
|
||||
func WithMaxRetries(val int) Option {
|
||||
return func(opts *Options) {
|
||||
opts.maxRetries = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithInitialDelay -
|
||||
func WithInitialDelay(val time.Duration) Option {
|
||||
return func(opts *Options) {
|
||||
opts.initialDelay = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithMaxBackoff -
|
||||
func WithMaxBackoff(val time.Duration) Option {
|
||||
return func(opts *Options) {
|
||||
opts.maxBackoff = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithCallback -
|
||||
func WithCallback(val ...func(context.Context, Result) Result) Option {
|
||||
return func(opts *Options) {
|
||||
opts.callback = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithJitterPercent -
|
||||
func WithJitterPercent(val float64) Option {
|
||||
return func(opts *Options) {
|
||||
opts.jitterPercent = val
|
||||
}
|
||||
}
|
110
v2/publisher.go
110
v2/publisher.go
@@ -1,110 +0,0 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/codec"
|
||||
"github.com/oarkflow/mq/consts"
|
||||
"github.com/oarkflow/mq/jsonparser"
|
||||
)
|
||||
|
||||
type Publisher struct {
|
||||
id string
|
||||
opts Options
|
||||
}
|
||||
|
||||
func NewPublisher(id string, opts ...Option) *Publisher {
|
||||
options := setupOptions(opts...)
|
||||
return &Publisher{id: id, opts: options}
|
||||
}
|
||||
|
||||
func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error {
|
||||
headers := WithHeaders(ctx, map[string]string{
|
||||
consts.PublisherKey: p.id,
|
||||
consts.ContentType: consts.TypeJson,
|
||||
})
|
||||
if task.ID == "" {
|
||||
task.ID = NewID()
|
||||
}
|
||||
task.CreatedAt = time.Now()
|
||||
payload, err := json.Marshal(task)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
msg := codec.NewMessage(command, payload, queue, headers)
|
||||
if err := codec.SendMessage(conn, msg, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return p.waitForAck(conn)
|
||||
}
|
||||
|
||||
func (p *Publisher) waitForAck(conn net.Conn) error {
|
||||
msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if msg.Command == consts.PUBLISH_ACK {
|
||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||
log.Printf("PUBLISHER - PUBLISH_ACK ~> from %s on %s for Task %s", p.id, msg.Queue, taskID)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
|
||||
}
|
||||
|
||||
func (p *Publisher) waitForResponse(conn net.Conn) Result {
|
||||
msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption)
|
||||
if err != nil {
|
||||
return Result{Error: err}
|
||||
}
|
||||
if msg.Command == consts.RESPONSE {
|
||||
var result Result
|
||||
err = json.Unmarshal(msg.Payload, &result)
|
||||
return result
|
||||
}
|
||||
err = fmt.Errorf("expected RESPONSE, got: %v", msg.Command)
|
||||
return Result{Error: err}
|
||||
}
|
||||
|
||||
func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error {
|
||||
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to broker: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
return p.send(ctx, queue, task, conn, consts.PUBLISH)
|
||||
}
|
||||
|
||||
func (p *Publisher) onClose(ctx context.Context, conn net.Conn) error {
|
||||
fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Publisher) onError(ctx context.Context, conn net.Conn, err error) {
|
||||
fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr())
|
||||
}
|
||||
|
||||
func (p *Publisher) Request(ctx context.Context, queue string, task Task) Result {
|
||||
ctx = SetHeaders(ctx, map[string]string{
|
||||
consts.AwaitResponseKey: "true",
|
||||
})
|
||||
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to connect to broker: %w", err)
|
||||
return Result{Error: err}
|
||||
}
|
||||
defer conn.Close()
|
||||
err = p.send(ctx, queue, task, conn, consts.PUBLISH)
|
||||
resultCh := make(chan Result)
|
||||
go func() {
|
||||
defer close(resultCh)
|
||||
resultCh <- p.waitForResponse(conn)
|
||||
}()
|
||||
finalResult := <-resultCh
|
||||
return finalResult
|
||||
}
|
Reference in New Issue
Block a user