mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-05 16:06:55 +08:00
feat: separate broker
This commit is contained in:
448
broker.go
448
broker.go
@@ -4,37 +4,35 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oarkflow/xsync"
|
"github.com/oarkflow/xsync"
|
||||||
|
|
||||||
|
"github.com/oarkflow/mq/codec"
|
||||||
"github.com/oarkflow/mq/consts"
|
"github.com/oarkflow/mq/consts"
|
||||||
|
"github.com/oarkflow/mq/jsonparser"
|
||||||
|
"github.com/oarkflow/mq/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type QueuedTask struct {
|
||||||
|
Message *codec.Message
|
||||||
|
RetryCount int
|
||||||
|
}
|
||||||
|
|
||||||
type consumer struct {
|
type consumer struct {
|
||||||
id string
|
id string
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *consumer) send(ctx context.Context, cmd any) error {
|
|
||||||
return Write(ctx, p.conn, cmd)
|
|
||||||
}
|
|
||||||
|
|
||||||
type publisher struct {
|
type publisher struct {
|
||||||
id string
|
id string
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *publisher) send(ctx context.Context, cmd any) error {
|
|
||||||
return Write(ctx, p.conn, cmd)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Handler func(context.Context, Task) Result
|
|
||||||
|
|
||||||
type Broker struct {
|
type Broker struct {
|
||||||
queues xsync.IMap[string, *Queue]
|
queues xsync.IMap[string, *Queue]
|
||||||
consumers xsync.IMap[string, *consumer]
|
consumers xsync.IMap[string, *consumer]
|
||||||
@@ -42,100 +40,17 @@ type Broker struct {
|
|||||||
opts Options
|
opts Options
|
||||||
}
|
}
|
||||||
|
|
||||||
type Queue struct {
|
|
||||||
name string
|
|
||||||
consumers xsync.IMap[string, *consumer]
|
|
||||||
messages xsync.IMap[string, *Task]
|
|
||||||
deferred xsync.IMap[string, *Task]
|
|
||||||
}
|
|
||||||
|
|
||||||
func newQueue(name string) *Queue {
|
|
||||||
return &Queue{
|
|
||||||
name: name,
|
|
||||||
consumers: xsync.NewMap[string, *consumer](),
|
|
||||||
messages: xsync.NewMap[string, *Task](),
|
|
||||||
deferred: xsync.NewMap[string, *Task](),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (queue *Queue) send(ctx context.Context, cmd any) {
|
|
||||||
queue.consumers.ForEach(func(_ string, client *consumer) bool {
|
|
||||||
err := client.send(ctx, cmd)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type Task struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Payload json.RawMessage `json:"payload"`
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
ProcessedAt time.Time `json:"processed_at"`
|
|
||||||
CurrentQueue string `json:"current_queue"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Error error `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Command struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Command consts.CMD `json:"command"`
|
|
||||||
Queue string `json:"queue"`
|
|
||||||
MessageID string `json:"message_id"`
|
|
||||||
Payload json.RawMessage `json:"payload,omitempty"` // Used for carrying the task payload
|
|
||||||
Error string `json:"error,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Result struct {
|
|
||||||
Command string `json:"command"`
|
|
||||||
Payload json.RawMessage `json:"payload"`
|
|
||||||
Queue string `json:"queue"`
|
|
||||||
MessageID string `json:"message_id"`
|
|
||||||
Error error `json:"error"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewBroker(opts ...Option) *Broker {
|
func NewBroker(opts ...Option) *Broker {
|
||||||
options := defaultOptions()
|
options := setupOptions(opts...)
|
||||||
for _, opt := range opts {
|
return &Broker{
|
||||||
opt(&options)
|
|
||||||
}
|
|
||||||
b := &Broker{
|
|
||||||
queues: xsync.NewMap[string, *Queue](),
|
queues: xsync.NewMap[string, *Queue](),
|
||||||
publishers: xsync.NewMap[string, *publisher](),
|
publishers: xsync.NewMap[string, *publisher](),
|
||||||
consumers: xsync.NewMap[string, *consumer](),
|
consumers: xsync.NewMap[string, *consumer](),
|
||||||
|
opts: options,
|
||||||
}
|
}
|
||||||
b.opts = defaultHandlers(options, b.onMessage, b.onClose, b.onError)
|
|
||||||
return b
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) Send(ctx context.Context, cmd Command) error {
|
func (b *Broker) OnClose(ctx context.Context, _ net.Conn) error {
|
||||||
queue, ok := b.queues.Get(cmd.Queue)
|
|
||||||
if !ok || queue == nil {
|
|
||||||
return errors.New("invalid queue or not exists")
|
|
||||||
}
|
|
||||||
queue.send(ctx, cmd)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) TLSConfig() TLSConfig {
|
|
||||||
return b.opts.tlsConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) SyncMode() bool {
|
|
||||||
return b.opts.syncMode
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) sendToPublisher(ctx context.Context, publisherID string, result Result) error {
|
|
||||||
pub, ok := b.publishers.Get(publisherID)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return pub.send(ctx, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) onClose(ctx context.Context, _ net.Conn) error {
|
|
||||||
consumerID, ok := GetConsumerID(ctx)
|
consumerID, ok := GetConsumerID(ctx)
|
||||||
if ok && consumerID != "" {
|
if ok && consumerID != "" {
|
||||||
if con, exists := b.consumers.Get(consumerID); exists {
|
if con, exists := b.consumers.Get(consumerID); exists {
|
||||||
@@ -157,11 +72,94 @@ func (b *Broker) onClose(ctx context.Context, _ net.Conn) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) onError(_ context.Context, conn net.Conn, err error) {
|
func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) {
|
||||||
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
|
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start the broker server with optional TLS support
|
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||||||
|
switch msg.Command {
|
||||||
|
case consts.PUBLISH:
|
||||||
|
b.PublishHandler(ctx, conn, msg)
|
||||||
|
case consts.SUBSCRIBE:
|
||||||
|
b.SubscribeHandler(ctx, conn, msg)
|
||||||
|
case consts.MESSAGE_RESPONSE:
|
||||||
|
b.MessageResponseHandler(ctx, msg)
|
||||||
|
case consts.MESSAGE_ACK:
|
||||||
|
b.MessageAck(ctx, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) {
|
||||||
|
consumerID, _ := GetConsumerID(ctx)
|
||||||
|
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||||
|
log.Printf("BROKER - MESSAGE_ACK ~> %s on %s for Task %s", consumerID, msg.Queue, taskID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) {
|
||||||
|
msg.Command = consts.RESPONSE
|
||||||
|
headers, ok := GetHeaders(ctx)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.HandleCallback(ctx, msg)
|
||||||
|
awaitResponse, ok := headers[consts.AwaitResponseKey]
|
||||||
|
if !(ok && awaitResponse == "true") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
publisherID, exists := headers[consts.PublisherKey]
|
||||||
|
if !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
con, ok := b.publishers.Get(publisherID)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := b.send(con.conn, msg)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) Publish(ctx context.Context, task Task, queue string) error {
|
||||||
|
headers, _ := GetHeaders(ctx)
|
||||||
|
payload, _ := json.Marshal(task)
|
||||||
|
msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers)
|
||||||
|
b.broadcastToConsumers(ctx, msg)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
||||||
|
pub := b.addPublisher(ctx, msg.Queue, conn)
|
||||||
|
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||||
|
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
|
||||||
|
|
||||||
|
ack := codec.NewMessage(consts.PUBLISH_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
|
||||||
|
if err := b.send(conn, ack); err != nil {
|
||||||
|
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
|
||||||
|
}
|
||||||
|
b.broadcastToConsumers(ctx, msg)
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
b.publishers.Del(pub.id)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
||||||
|
consumerID := b.addConsumer(ctx, msg.Queue, conn)
|
||||||
|
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
|
||||||
|
if err := b.send(conn, ack); err != nil {
|
||||||
|
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
b.removeConsumer(msg.Queue, consumerID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
func (b *Broker) Start(ctx context.Context) error {
|
func (b *Broker) Start(ctx context.Context) error {
|
||||||
var listener net.Listener
|
var listener net.Listener
|
||||||
var err error
|
var err error
|
||||||
@@ -178,113 +176,61 @@ func (b *Broker) Start(ctx context.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to start TLS listener: %v", err)
|
return fmt.Errorf("failed to start TLS listener: %v", err)
|
||||||
}
|
}
|
||||||
log.Println("TLS server started on", b.opts.brokerAddr)
|
log.Println("BROKER - RUNNING_TLS ~> started on", b.opts.brokerAddr)
|
||||||
} else {
|
} else {
|
||||||
listener, err = net.Listen("tcp", b.opts.brokerAddr)
|
listener, err = net.Listen("tcp", b.opts.brokerAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to start TCP listener: %v", err)
|
return fmt.Errorf("failed to start TCP listener: %v", err)
|
||||||
}
|
}
|
||||||
log.Println("TCP server started on", b.opts.brokerAddr)
|
log.Println("BROKER - RUNNING ~> started on", b.opts.brokerAddr)
|
||||||
}
|
}
|
||||||
defer listener.Close()
|
defer listener.Close()
|
||||||
for {
|
for {
|
||||||
conn, err := listener.Accept()
|
conn, err := listener.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("Error accepting connection:", err)
|
b.OnError(ctx, conn, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
go ReadFromConn(ctx, conn, Handlers{
|
go func(c net.Conn) {
|
||||||
MessageHandler: b.opts.messageHandler,
|
defer c.Close()
|
||||||
CloseHandler: b.opts.closeHandler,
|
for {
|
||||||
ErrorHandler: b.opts.errorHandler,
|
err := b.readMessage(ctx, c)
|
||||||
})
|
if err != nil {
|
||||||
}
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) Publish(ctx context.Context, message Task, queueName string) Result {
|
|
||||||
queue, task, err := b.AddMessageToQueue(&message, queueName)
|
|
||||||
if err != nil {
|
|
||||||
return Result{Error: err}
|
|
||||||
}
|
|
||||||
result := Result{
|
|
||||||
Command: "PUBLISH",
|
|
||||||
Payload: message.Payload,
|
|
||||||
Queue: queueName,
|
|
||||||
MessageID: task.ID,
|
|
||||||
}
|
|
||||||
if queue.consumers.Size() == 0 {
|
|
||||||
queue.deferred.Set(NewID(), &message)
|
|
||||||
fmt.Println("task deferred as no consumers are connected", queueName)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
queue.send(ctx, message)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) NewQueue(qName string) *Queue {
|
|
||||||
q, ok := b.queues.Get(qName)
|
|
||||||
if ok {
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
q = newQueue(qName)
|
|
||||||
b.queues.Set(qName, q)
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) AddMessageToQueue(task *Task, queueName string) (*Queue, *Task, error) {
|
|
||||||
queue := b.NewQueue(queueName)
|
|
||||||
if task.ID == "" {
|
|
||||||
task.ID = NewID()
|
|
||||||
}
|
|
||||||
if queueName != "" {
|
|
||||||
task.CurrentQueue = queueName
|
|
||||||
}
|
|
||||||
task.CreatedAt = time.Now()
|
|
||||||
queue.messages.Set(task.ID, task)
|
|
||||||
return queue, task, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) HandleProcessedMessage(ctx context.Context, result Result) error {
|
|
||||||
publisherID, ok := GetPublisherID(ctx)
|
|
||||||
if ok && publisherID != "" {
|
|
||||||
err := b.sendToPublisher(ctx, publisherID, result)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, callback := range b.opts.callback {
|
|
||||||
if callback != nil {
|
|
||||||
rs := callback(ctx, result)
|
|
||||||
if rs.Error != nil {
|
|
||||||
return rs.Error
|
|
||||||
}
|
}
|
||||||
}
|
}(conn)
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string {
|
func (b *Broker) send(conn net.Conn, msg *codec.Message) error {
|
||||||
consumerID, ok := GetConsumerID(ctx)
|
return codec.SendMessage(conn, msg, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
|
||||||
defer func() {
|
|
||||||
cmd := Command{
|
|
||||||
Command: consts.SUBSCRIBE_ACK,
|
|
||||||
Queue: queueName,
|
|
||||||
Error: "",
|
|
||||||
}
|
|
||||||
Write(ctx, conn, cmd)
|
|
||||||
log.Printf("Consumer %s joined server on queue %s", consumerID, queueName)
|
|
||||||
}()
|
|
||||||
q, ok := b.queues.Get(queueName)
|
|
||||||
if !ok {
|
|
||||||
q = b.NewQueue(queueName)
|
|
||||||
}
|
|
||||||
con := &consumer{id: consumerID, conn: conn}
|
|
||||||
b.consumers.Set(consumerID, con)
|
|
||||||
q.consumers.Set(consumerID, con)
|
|
||||||
return consumerID
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) string {
|
func (b *Broker) receive(c net.Conn) (*codec.Message, error) {
|
||||||
|
return codec.ReadMessage(c, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) {
|
||||||
|
if queue, ok := b.queues.Get(msg.Queue); ok {
|
||||||
|
task := &QueuedTask{Message: msg, RetryCount: 0}
|
||||||
|
queue.tasks <- task
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) waitForConsumerAck(conn net.Conn) error {
|
||||||
|
msg, err := b.receive(conn)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if msg.Command == consts.MESSAGE_ACK {
|
||||||
|
log.Println("Received CONSUMER_ACK: Subscribed successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("expected CONSUMER_ACK, got: %v", msg.Command)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) *publisher {
|
||||||
publisherID, ok := GetPublisherID(ctx)
|
publisherID, ok := GetPublisherID(ctx)
|
||||||
_, ok = b.queues.Get(queueName)
|
_, ok = b.queues.Get(queueName)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -292,20 +238,22 @@ func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Co
|
|||||||
}
|
}
|
||||||
con := &publisher{id: publisherID, conn: conn}
|
con := &publisher{id: publisherID, conn: conn}
|
||||||
b.publishers.Set(publisherID, con)
|
b.publishers.Set(publisherID, con)
|
||||||
return publisherID
|
return con
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) subscribe(ctx context.Context, queueName string, conn net.Conn) {
|
func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string {
|
||||||
consumerID := b.addConsumer(ctx, queueName, conn)
|
consumerID, ok := GetConsumerID(ctx)
|
||||||
go func() {
|
q, ok := b.queues.Get(queueName)
|
||||||
select {
|
if !ok {
|
||||||
case <-ctx.Done():
|
q = b.NewQueue(queueName)
|
||||||
b.removeConsumer(queueName, consumerID)
|
}
|
||||||
}
|
con := &consumer{id: consumerID, conn: conn}
|
||||||
}()
|
b.consumers.Set(consumerID, con)
|
||||||
|
q.consumers.Set(consumerID, con)
|
||||||
|
log.Printf("BROKER - SUBSCRIBE ~> %s on %s", consumerID, queueName)
|
||||||
|
return consumerID
|
||||||
}
|
}
|
||||||
|
|
||||||
// Removes connection from the queue and broker
|
|
||||||
func (b *Broker) removeConsumer(queueName, consumerID string) {
|
func (b *Broker) removeConsumer(queueName, consumerID string) {
|
||||||
if queue, ok := b.queues.Get(queueName); ok {
|
if queue, ok := b.queues.Get(queueName); ok {
|
||||||
con, ok := queue.consumers.Get(consumerID)
|
con, ok := queue.consumers.Get(consumerID)
|
||||||
@@ -317,57 +265,59 @@ func (b *Broker) removeConsumer(queueName, consumerID string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) onMessage(ctx context.Context, conn net.Conn, message []byte) error {
|
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
|
||||||
var cmdMsg Command
|
msg, err := b.receive(c)
|
||||||
var resultMsg Result
|
|
||||||
err := json.Unmarshal(message, &cmdMsg)
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return b.handleCommandMessage(ctx, conn, cmdMsg)
|
ctx = SetHeaders(ctx, msg.Headers)
|
||||||
}
|
b.OnMessage(ctx, msg, c)
|
||||||
err = json.Unmarshal(message, &resultMsg)
|
|
||||||
if err == nil {
|
|
||||||
return b.handleTaskMessage(ctx, conn, resultMsg)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) handleTaskMessage(ctx context.Context, _ net.Conn, msg Result) error {
|
|
||||||
return b.HandleProcessedMessage(ctx, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) publish(ctx context.Context, conn net.Conn, msg Command) error {
|
|
||||||
status := "PUBLISH"
|
|
||||||
if msg.Command == consts.REQUEST {
|
|
||||||
status = "REQUEST"
|
|
||||||
}
|
|
||||||
b.addPublisher(ctx, msg.Queue, conn)
|
|
||||||
task := Task{
|
|
||||||
ID: msg.MessageID,
|
|
||||||
Payload: msg.Payload,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
CurrentQueue: msg.Queue,
|
|
||||||
}
|
|
||||||
result := b.Publish(ctx, task, msg.Queue)
|
|
||||||
if result.Error != nil {
|
|
||||||
return result.Error
|
|
||||||
}
|
|
||||||
if task.ID != "" {
|
|
||||||
result.Status = status
|
|
||||||
result.MessageID = task.ID
|
|
||||||
result.Queue = msg.Queue
|
|
||||||
return Write(ctx, conn, result)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) handleCommandMessage(ctx context.Context, conn net.Conn, msg Command) error {
|
|
||||||
switch msg.Command {
|
|
||||||
case consts.SUBSCRIBE:
|
|
||||||
b.subscribe(ctx, msg.Queue, conn)
|
|
||||||
return nil
|
return nil
|
||||||
case consts.PUBLISH, consts.REQUEST:
|
}
|
||||||
return b.publish(ctx, conn, msg)
|
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
|
||||||
default:
|
b.OnClose(ctx, c)
|
||||||
return fmt.Errorf("unknown command: %d", msg.Command)
|
return err
|
||||||
|
}
|
||||||
|
b.OnError(ctx, c, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) dispatchWorker(queue *Queue) {
|
||||||
|
delay := b.opts.initialDelay
|
||||||
|
for task := range queue.tasks {
|
||||||
|
success := false
|
||||||
|
for !success && task.RetryCount <= b.opts.maxRetries {
|
||||||
|
if b.dispatchTaskToConsumer(queue, task) {
|
||||||
|
success = true
|
||||||
|
} else {
|
||||||
|
task.RetryCount++
|
||||||
|
delay = b.backoffRetry(queue, task, delay)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool {
|
||||||
|
var consumerFound bool
|
||||||
|
queue.consumers.ForEach(func(_ string, con *consumer) bool {
|
||||||
|
if err := b.send(con.conn, task.Message); err == nil {
|
||||||
|
consumerFound = true
|
||||||
|
return false // break the loop once a consumer is found
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if !consumerFound {
|
||||||
|
log.Printf("No available consumers for queue %s, retrying...", queue.name)
|
||||||
|
}
|
||||||
|
return consumerFound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) backoffRetry(queue *Queue, task *QueuedTask, delay time.Duration) time.Duration {
|
||||||
|
backoffDuration := utils.CalculateJitter(delay, b.opts.jitterPercent)
|
||||||
|
log.Printf("Backing off for %v before retrying task for queue %s", backoffDuration, task.Message.Queue)
|
||||||
|
time.Sleep(backoffDuration)
|
||||||
|
queue.tasks <- task
|
||||||
|
delay *= 2
|
||||||
|
if delay > b.opts.maxBackoff {
|
||||||
|
delay = b.opts.maxBackoff
|
||||||
|
}
|
||||||
|
return delay
|
||||||
|
}
|
||||||
|
195
consumer.go
195
consumer.go
@@ -7,10 +7,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/oarkflow/mq/codec"
|
||||||
"github.com/oarkflow/mq/consts"
|
"github.com/oarkflow/mq/consts"
|
||||||
|
"github.com/oarkflow/mq/jsonparser"
|
||||||
"github.com/oarkflow/mq/utils"
|
"github.com/oarkflow/mq/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -25,16 +28,20 @@ type Consumer struct {
|
|||||||
|
|
||||||
// NewConsumer initializes a new consumer with the provided options.
|
// NewConsumer initializes a new consumer with the provided options.
|
||||||
func NewConsumer(id string, opts ...Option) *Consumer {
|
func NewConsumer(id string, opts ...Option) *Consumer {
|
||||||
options := defaultOptions()
|
options := setupOptions(opts...)
|
||||||
for _, opt := range opts {
|
return &Consumer{
|
||||||
opt(&options)
|
|
||||||
}
|
|
||||||
b := &Consumer{
|
|
||||||
handlers: make(map[string]Handler),
|
handlers: make(map[string]Handler),
|
||||||
id: id,
|
id: id,
|
||||||
|
opts: options,
|
||||||
}
|
}
|
||||||
b.opts = defaultHandlers(options, b.onMessage, b.onClose, b.onError)
|
}
|
||||||
return b
|
|
||||||
|
func (c *Consumer) send(conn net.Conn, msg *codec.Message) error {
|
||||||
|
return codec.SendMessage(conn, msg, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) {
|
||||||
|
return codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the consumer's connection.
|
// Close closes the consumer's connection.
|
||||||
@@ -43,90 +50,82 @@ func (c *Consumer) Close() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Subscribe to a specific queue.
|
// Subscribe to a specific queue.
|
||||||
func (c *Consumer) subscribe(queue string) error {
|
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
|
||||||
ctx := context.Background()
|
headers := WithHeaders(ctx, map[string]string{
|
||||||
ctx = SetHeaders(ctx, map[string]string{
|
|
||||||
consts.ConsumerKey: c.id,
|
consts.ConsumerKey: c.id,
|
||||||
consts.ContentType: consts.TypeJson,
|
consts.ContentType: consts.TypeJson,
|
||||||
})
|
})
|
||||||
subscribe := Command{
|
msg := codec.NewMessage(consts.SUBSCRIBE, nil, queue, headers)
|
||||||
Command: consts.SUBSCRIBE,
|
if err := c.send(c.conn, msg); err != nil {
|
||||||
Queue: queue,
|
return err
|
||||||
ID: NewID(),
|
}
|
||||||
|
|
||||||
|
return c.waitForAck(c.conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) OnClose(ctx context.Context, _ net.Conn) error {
|
||||||
|
fmt.Println("Consumer closed")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) {
|
||||||
|
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||||||
|
headers := WithHeaders(ctx, map[string]string{
|
||||||
|
consts.ConsumerKey: c.id,
|
||||||
|
consts.ContentType: consts.TypeJson,
|
||||||
|
})
|
||||||
|
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||||
|
reply := codec.NewMessage(consts.MESSAGE_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
|
||||||
|
if err := c.send(conn, reply); err != nil {
|
||||||
|
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
|
||||||
|
}
|
||||||
|
var task Task
|
||||||
|
err := json.Unmarshal(msg.Payload, &task)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Error unmarshalling message:", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue})
|
||||||
|
result := c.ProcessTask(ctx, task)
|
||||||
|
result.MessageID = task.ID
|
||||||
|
result.Queue = msg.Queue
|
||||||
|
if result.Error != nil {
|
||||||
|
result.Status = "FAILED"
|
||||||
|
} else {
|
||||||
|
result.Status = "SUCCESS"
|
||||||
|
}
|
||||||
|
bt, _ := json.Marshal(result)
|
||||||
|
reply = codec.NewMessage(consts.MESSAGE_RESPONSE, bt, msg.Queue, headers)
|
||||||
|
if err := c.send(conn, reply); err != nil {
|
||||||
|
fmt.Printf("failed to send MESSAGE_RESPONSE for queue %s: %v", msg.Queue, err)
|
||||||
}
|
}
|
||||||
return Write(ctx, c.conn, subscribe)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessTask handles a received task message and invokes the appropriate handler.
|
// ProcessTask handles a received task message and invokes the appropriate handler.
|
||||||
func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result {
|
func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result {
|
||||||
handler, exists := c.handlers[msg.CurrentQueue]
|
queue, _ := GetQueue(ctx)
|
||||||
|
handler, exists := c.handlers[queue]
|
||||||
if !exists {
|
if !exists {
|
||||||
return Result{Error: errors.New("No handler for queue " + msg.CurrentQueue)}
|
return Result{Error: errors.New("No handler for queue " + queue)}
|
||||||
}
|
}
|
||||||
return handler(ctx, msg)
|
return handler(ctx, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle command message sent by the server.
|
|
||||||
func (c *Consumer) handleCommandMessage(msg Command) error {
|
|
||||||
switch msg.Command {
|
|
||||||
case consts.STOP:
|
|
||||||
return c.Close()
|
|
||||||
case consts.SUBSCRIBE_ACK:
|
|
||||||
log.Printf("Consumer %s subscribed to queue %s\n", c.id, msg.Queue)
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unknown command in consumer %d", msg.Command)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle task message sent by the server.
|
|
||||||
func (c *Consumer) handleTaskMessage(ctx context.Context, msg Task) error {
|
|
||||||
response := c.ProcessTask(ctx, msg)
|
|
||||||
response.Queue = msg.CurrentQueue
|
|
||||||
if msg.ID == "" {
|
|
||||||
response.Error = errors.New("task ID is empty")
|
|
||||||
response.Command = "error"
|
|
||||||
} else {
|
|
||||||
response.Command = "completed"
|
|
||||||
response.MessageID = msg.ID
|
|
||||||
}
|
|
||||||
return c.sendResult(ctx, response)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send the result of task processing back to the server.
|
|
||||||
func (c *Consumer) sendResult(ctx context.Context, response Result) error {
|
|
||||||
return Write(ctx, c.conn, response)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read and handle incoming messages.
|
|
||||||
func (c *Consumer) readMessage(ctx context.Context, message []byte) error {
|
|
||||||
var cmdMsg Command
|
|
||||||
var task Task
|
|
||||||
err := json.Unmarshal(message, &cmdMsg)
|
|
||||||
if err == nil && cmdMsg.Command != 0 {
|
|
||||||
return c.handleCommandMessage(cmdMsg)
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(message, &task)
|
|
||||||
if err == nil {
|
|
||||||
return c.handleTaskMessage(ctx, task)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AttemptConnect tries to establish a connection to the server, with TLS or without, based on the configuration.
|
// AttemptConnect tries to establish a connection to the server, with TLS or without, based on the configuration.
|
||||||
func (c *Consumer) AttemptConnect() error {
|
func (c *Consumer) AttemptConnect() error {
|
||||||
var err error
|
var err error
|
||||||
delay := c.opts.initialDelay
|
delay := c.opts.initialDelay
|
||||||
|
|
||||||
for i := 0; i < c.opts.maxRetries; i++ {
|
for i := 0; i < c.opts.maxRetries; i++ {
|
||||||
conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig)
|
conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
c.conn = conn
|
c.conn = conn
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent)
|
sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent)
|
||||||
fmt.Printf("Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
|
log.Printf("CONSUMER - SUBSCRIBE ~> Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
|
||||||
time.Sleep(sleepDuration)
|
time.Sleep(sleepDuration)
|
||||||
delay *= 2
|
delay *= 2
|
||||||
if delay > c.opts.maxBackoff {
|
if delay > c.opts.maxBackoff {
|
||||||
@@ -137,20 +136,19 @@ func (c *Consumer) AttemptConnect() error {
|
|||||||
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err)
|
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// onMessage reads incoming messages from the connection.
|
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
|
||||||
func (c *Consumer) onMessage(ctx context.Context, conn net.Conn, message []byte) error {
|
msg, err := c.receive(conn)
|
||||||
return c.readMessage(ctx, message)
|
if err == nil {
|
||||||
}
|
ctx = SetHeaders(ctx, msg.Headers)
|
||||||
|
c.OnMessage(ctx, msg, conn)
|
||||||
// onClose handles connection close event.
|
return nil
|
||||||
func (c *Consumer) onClose(ctx context.Context, conn net.Conn) error {
|
}
|
||||||
fmt.Println("Consumer Connection closed", c.id, conn.RemoteAddr())
|
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
|
||||||
return nil
|
c.OnClose(ctx, conn)
|
||||||
}
|
return err
|
||||||
|
}
|
||||||
// onError handles errors while reading from the connection.
|
c.OnError(ctx, conn, err)
|
||||||
func (c *Consumer) onError(ctx context.Context, conn net.Conn, err error) {
|
return err
|
||||||
fmt.Println("Error reading from consumer connection:", err, conn.RemoteAddr())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Consume starts the consumer to consume tasks from the queues.
|
// Consume starts the consumer to consume tasks from the queues.
|
||||||
@@ -159,26 +157,39 @@ func (c *Consumer) Consume(ctx context.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
for _, q := range c.queues {
|
||||||
|
if err := c.subscribe(ctx, q); err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to server for queue %s: %v", q, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
ReadFromConn(ctx, c.conn, Handlers{
|
for {
|
||||||
MessageHandler: c.opts.messageHandler,
|
if err := c.readMessage(ctx, c.conn); err != nil {
|
||||||
CloseHandler: c.opts.closeHandler,
|
log.Println("Error reading message:", err)
|
||||||
ErrorHandler: c.opts.errorHandler,
|
break
|
||||||
})
|
}
|
||||||
fmt.Println("Stopping consumer")
|
|
||||||
}()
|
|
||||||
for _, q := range c.queues {
|
|
||||||
if err := c.subscribe(q); err != nil {
|
|
||||||
return fmt.Errorf("failed to connect to server for queue %s: %v", q, err)
|
|
||||||
}
|
}
|
||||||
}
|
}()
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Consumer) waitForAck(conn net.Conn) error {
|
||||||
|
msg, err := c.receive(conn)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if msg.Command == consts.SUBSCRIBE_ACK {
|
||||||
|
log.Printf("CONSUMER - SUBSCRIBE_ACK ~> %s on %s", c.id, msg.Queue)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command)
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterHandler registers a handler for a queue.
|
// RegisterHandler registers a handler for a queue.
|
||||||
func (c *Consumer) RegisterHandler(queue string, handler Handler) {
|
func (c *Consumer) RegisterHandler(queue string, handler Handler) {
|
||||||
c.queues = append(c.queues, queue)
|
c.queues = append(c.queues, queue)
|
||||||
|
118
ctx.go
118
ctx.go
@@ -1,36 +1,31 @@
|
|||||||
package mq
|
package mq
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"time"
|
||||||
|
|
||||||
"github.com/oarkflow/xid"
|
"github.com/oarkflow/xid"
|
||||||
|
|
||||||
"github.com/oarkflow/mq/codec"
|
|
||||||
"github.com/oarkflow/mq/consts"
|
"github.com/oarkflow/mq/consts"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MessageHandler func(context.Context, net.Conn, []byte) error
|
type Task struct {
|
||||||
|
ID string `json:"id"`
|
||||||
type CloseHandler func(context.Context, net.Conn) error
|
Payload json.RawMessage `json:"payload"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
type ErrorHandler func(context.Context, net.Conn, error)
|
ProcessedAt time.Time `json:"processed_at"`
|
||||||
|
Status string `json:"status"`
|
||||||
type Handlers struct {
|
Error error `json:"error"`
|
||||||
MessageHandler MessageHandler
|
|
||||||
CloseHandler CloseHandler
|
|
||||||
ErrorHandler ErrorHandler
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Handler func(context.Context, Task) Result
|
||||||
|
|
||||||
func IsClosed(conn net.Conn) bool {
|
func IsClosed(conn net.Conn) bool {
|
||||||
_, err := conn.Read(make([]byte, 1))
|
_, err := conn.Read(make([]byte, 1))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -52,11 +47,31 @@ func SetHeaders(ctx context.Context, headers map[string]string) context.Context
|
|||||||
return context.WithValue(ctx, consts.HeaderKey, hd)
|
return context.WithValue(ctx, consts.HeaderKey, hd)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithHeaders(ctx context.Context, headers map[string]string) map[string]string {
|
||||||
|
hd, ok := GetHeaders(ctx)
|
||||||
|
if !ok {
|
||||||
|
hd = make(map[string]string)
|
||||||
|
}
|
||||||
|
for key, val := range headers {
|
||||||
|
hd[key] = val
|
||||||
|
}
|
||||||
|
return hd
|
||||||
|
}
|
||||||
|
|
||||||
func GetHeaders(ctx context.Context) (map[string]string, bool) {
|
func GetHeaders(ctx context.Context) (map[string]string, bool) {
|
||||||
headers, ok := ctx.Value(consts.HeaderKey).(map[string]string)
|
headers, ok := ctx.Value(consts.HeaderKey).(map[string]string)
|
||||||
return headers, ok
|
return headers, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetHeader(ctx context.Context, key string) (string, bool) {
|
||||||
|
headers, ok := ctx.Value(consts.HeaderKey).(map[string]string)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
val, ok := headers[key]
|
||||||
|
return val, ok
|
||||||
|
}
|
||||||
|
|
||||||
func GetContentType(ctx context.Context) (string, bool) {
|
func GetContentType(ctx context.Context) (string, bool) {
|
||||||
headers, ok := GetHeaders(ctx)
|
headers, ok := GetHeaders(ctx)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -66,6 +81,15 @@ func GetContentType(ctx context.Context) (string, bool) {
|
|||||||
return contentType, ok
|
return contentType, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetQueue(ctx context.Context) (string, bool) {
|
||||||
|
headers, ok := GetHeaders(ctx)
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
contentType, ok := headers[consts.QueueKey]
|
||||||
|
return contentType, ok
|
||||||
|
}
|
||||||
|
|
||||||
func GetConsumerID(ctx context.Context) (string, bool) {
|
func GetConsumerID(ctx context.Context) (string, bool) {
|
||||||
headers, ok := GetHeaders(ctx)
|
headers, ok := GetHeaders(ctx)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -93,70 +117,6 @@ func GetPublisherID(ctx context.Context) (string, bool) {
|
|||||||
return contentType, ok
|
return contentType, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
func Write(ctx context.Context, conn net.Conn, data any) error {
|
|
||||||
msg := codec.Message{Headers: make(map[string]string)}
|
|
||||||
if headers, ok := GetHeaders(ctx); ok {
|
|
||||||
msg.Headers = headers
|
|
||||||
}
|
|
||||||
dataBytes, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
msg.Payload = dataBytes
|
|
||||||
messageBytes, err := json.Marshal(msg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = conn.Write(append(messageBytes, '\n'))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReadFromConn(ctx context.Context, conn net.Conn, handlers Handlers) {
|
|
||||||
defer func() {
|
|
||||||
if handlers.CloseHandler != nil {
|
|
||||||
if err := handlers.CloseHandler(ctx, conn); err != nil {
|
|
||||||
fmt.Println("Error in close handler:", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.Close()
|
|
||||||
}()
|
|
||||||
reader := bufio.NewReader(conn)
|
|
||||||
for {
|
|
||||||
messageBytes, err := reader.ReadBytes('\n')
|
|
||||||
if err != nil {
|
|
||||||
if err == io.EOF || IsClosed(conn) || strings.Contains(err.Error(), "closed network connection") {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if handlers.ErrorHandler != nil {
|
|
||||||
handlers.ErrorHandler(ctx, conn, err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
messageBytes = bytes.TrimSpace(messageBytes)
|
|
||||||
if len(messageBytes) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
var msg codec.Message
|
|
||||||
err = json.Unmarshal(messageBytes, &msg)
|
|
||||||
if err != nil {
|
|
||||||
if handlers.ErrorHandler != nil {
|
|
||||||
handlers.ErrorHandler(ctx, conn, err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ctx = SetHeaders(ctx, msg.Headers)
|
|
||||||
if handlers.MessageHandler != nil {
|
|
||||||
err = handlers.MessageHandler(ctx, conn, msg.Payload)
|
|
||||||
if err != nil {
|
|
||||||
if handlers.ErrorHandler != nil {
|
|
||||||
handlers.ErrorHandler(ctx, conn, err)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewID() string {
|
func NewID() string {
|
||||||
return xid.New().String()
|
return xid.New().String()
|
||||||
}
|
}
|
||||||
|
137
dag/dag.go
137
dag/dag.go
@@ -3,13 +3,13 @@ package dag
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"github.com/oarkflow/mq/consts"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/oarkflow/mq"
|
"github.com/oarkflow/mq"
|
||||||
"github.com/oarkflow/mq/consts"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type taskContext struct {
|
type taskContext struct {
|
||||||
@@ -76,12 +76,20 @@ func (d *DAG) Start(ctx context.Context, addr string) error {
|
|||||||
if d.server.SyncMode() {
|
if d.server.SyncMode() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
for _, con := range d.nodes {
|
|
||||||
go con.Consume(ctx)
|
|
||||||
}
|
|
||||||
go func() {
|
go func() {
|
||||||
d.server.Start(ctx)
|
err := d.server.Start(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
for _, con := range d.nodes {
|
||||||
|
go func(con *mq.Consumer) {
|
||||||
|
err := con.Consume(ctx)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}(con)
|
||||||
|
}
|
||||||
log.Printf("HTTP server started on %s", addr)
|
log.Printf("HTTP server started on %s", addr)
|
||||||
config := d.server.TLSConfig()
|
config := d.server.TLSConfig()
|
||||||
if config.UseTLS {
|
if config.UseTLS {
|
||||||
@@ -90,16 +98,6 @@ func (d *DAG) Start(ctx context.Context, addr string) error {
|
|||||||
return http.ListenAndServe(addr, nil)
|
return http.ListenAndServe(addr, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DAG) PublishTask(ctx context.Context, payload []byte, queueName string, taskID ...string) mq.Result {
|
|
||||||
task := mq.Task{
|
|
||||||
Payload: payload,
|
|
||||||
}
|
|
||||||
if len(taskID) > 0 {
|
|
||||||
task.ID = taskID[0]
|
|
||||||
}
|
|
||||||
return d.server.Publish(ctx, task, queueName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DAG) FindFirstNode() (string, bool) {
|
func (d *DAG) FindFirstNode() (string, bool) {
|
||||||
inDegree := make(map[string]int)
|
inDegree := make(map[string]int)
|
||||||
for n, _ := range d.nodes {
|
for n, _ := range d.nodes {
|
||||||
@@ -121,86 +119,23 @@ func (d *DAG) FindFirstNode() (string, bool) {
|
|||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DAG) Request(ctx context.Context, payload []byte) mq.Result {
|
func (d *DAG) PublishTask(ctx context.Context, payload json.RawMessage, taskID ...string) error {
|
||||||
return d.sendSync(ctx, mq.Result{Payload: payload})
|
queue, ok := mq.GetQueue(ctx)
|
||||||
}
|
if !ok {
|
||||||
|
queue = d.FirstNode
|
||||||
func (d *DAG) Send(ctx context.Context, payload []byte) mq.Result {
|
|
||||||
if d.FirstNode == "" {
|
|
||||||
return mq.Result{Error: fmt.Errorf("initial node not defined")}
|
|
||||||
}
|
}
|
||||||
if d.server.SyncMode() {
|
var id string
|
||||||
return d.sendSync(ctx, mq.Result{Payload: payload})
|
if len(taskID) > 0 {
|
||||||
|
id = taskID[0]
|
||||||
|
} else {
|
||||||
|
id = mq.NewID()
|
||||||
}
|
}
|
||||||
resultCh := make(chan mq.Result)
|
task := mq.Task{
|
||||||
result := d.PublishTask(ctx, payload, d.FirstNode)
|
ID: id,
|
||||||
if result.Error != nil {
|
Payload: payload,
|
||||||
return result
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
d.mu.Lock()
|
return d.server.Publish(ctx, task, queue)
|
||||||
d.taskChMap[result.MessageID] = resultCh
|
|
||||||
d.mu.Unlock()
|
|
||||||
finalResult := <-resultCh
|
|
||||||
return finalResult
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DAG) processNode(ctx context.Context, task mq.Result) mq.Result {
|
|
||||||
if con, ok := d.nodes[task.Queue]; ok {
|
|
||||||
return con.ProcessTask(ctx, mq.Task{
|
|
||||||
ID: task.MessageID,
|
|
||||||
Payload: task.Payload,
|
|
||||||
CurrentQueue: task.Queue,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return mq.Result{Error: fmt.Errorf("no consumer to process %s", task.Queue)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DAG) sendSync(ctx context.Context, task mq.Result) mq.Result {
|
|
||||||
if task.MessageID == "" {
|
|
||||||
task.MessageID = mq.NewID()
|
|
||||||
}
|
|
||||||
if task.Queue == "" {
|
|
||||||
task.Queue = d.FirstNode
|
|
||||||
}
|
|
||||||
result := d.processNode(ctx, task)
|
|
||||||
if result.Error != nil {
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
for _, target := range d.loopEdges[task.Queue] {
|
|
||||||
var items, results []json.RawMessage
|
|
||||||
if err := json.Unmarshal(result.Payload, &items); err != nil {
|
|
||||||
return mq.Result{Error: err}
|
|
||||||
}
|
|
||||||
for _, item := range items {
|
|
||||||
result = d.sendSync(ctx, mq.Result{
|
|
||||||
Command: result.Command,
|
|
||||||
Payload: item,
|
|
||||||
Queue: target,
|
|
||||||
MessageID: result.MessageID,
|
|
||||||
})
|
|
||||||
if result.Error != nil {
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
results = append(results, result.Payload)
|
|
||||||
}
|
|
||||||
bt, err := json.Marshal(results)
|
|
||||||
if err != nil {
|
|
||||||
return mq.Result{Error: err}
|
|
||||||
}
|
|
||||||
result.Payload = bt
|
|
||||||
}
|
|
||||||
if target, ok := d.edges[task.Queue]; ok {
|
|
||||||
result = d.sendSync(ctx, mq.Result{
|
|
||||||
Command: result.Command,
|
|
||||||
Payload: result.Payload,
|
|
||||||
Queue: target,
|
|
||||||
MessageID: result.MessageID,
|
|
||||||
})
|
|
||||||
if result.Error != nil {
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DAG) getCompletedResults(task mq.Result, ok bool, triggeredNode string) ([]byte, bool, bool) {
|
func (d *DAG) getCompletedResults(task mq.Result, ok bool, triggeredNode string) ([]byte, bool, bool) {
|
||||||
@@ -264,9 +199,12 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
|
|||||||
ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue})
|
ctx = mq.SetHeaders(ctx, map[string]string{consts.TriggerNode: task.Queue})
|
||||||
for _, loopNode := range loopNodes {
|
for _, loopNode := range loopNodes {
|
||||||
for _, item := range items {
|
for _, item := range items {
|
||||||
rs := d.PublishTask(ctx, item, loopNode, task.MessageID)
|
ctx = mq.SetHeaders(ctx, map[string]string{
|
||||||
if rs.Error != nil {
|
consts.QueueKey: loopNode,
|
||||||
return rs
|
})
|
||||||
|
err := d.PublishTask(ctx, item, task.MessageID)
|
||||||
|
if err != nil {
|
||||||
|
return mq.Result{Error: err}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -284,15 +222,14 @@ func (d *DAG) TaskCallback(ctx context.Context, task mq.Result) mq.Result {
|
|||||||
totalItems: 1,
|
totalItems: 1,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
rs := d.PublishTask(ctx, payload, edge, task.MessageID)
|
err := d.PublishTask(ctx, payload, edge, task.MessageID)
|
||||||
if rs.Error != nil {
|
if err != nil {
|
||||||
return rs
|
return mq.Result{Error: err}
|
||||||
}
|
}
|
||||||
} else if completed {
|
} else if completed {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
if resultCh, ok := d.taskChMap[task.MessageID]; ok {
|
if resultCh, ok := d.taskChMap[task.MessageID]; ok {
|
||||||
resultCh <- mq.Result{
|
resultCh <- mq.Result{
|
||||||
Command: "complete",
|
|
||||||
Payload: payload,
|
Payload: payload,
|
||||||
Queue: task.Queue,
|
Queue: task.Queue,
|
||||||
MessageID: task.MessageID,
|
MessageID: task.MessageID,
|
||||||
|
@@ -2,9 +2,9 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"github.com/oarkflow/mq"
|
||||||
|
|
||||||
"github.com/oarkflow/mq/examples/tasks"
|
"github.com/oarkflow/mq/examples/tasks"
|
||||||
mq "github.com/oarkflow/mq/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
@@ -2,19 +2,15 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/oarkflow/mq"
|
|
||||||
"github.com/oarkflow/mq/dag"
|
"github.com/oarkflow/mq/dag"
|
||||||
"github.com/oarkflow/mq/examples/tasks"
|
"github.com/oarkflow/mq/examples/tasks"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var d *dag.DAG
|
var d *dag.DAG
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
d = dag.New(mq.WithTLS(true, "server.crt", "server.key"), mq.WithCAPath("ca.crt"))
|
d = dag.New()
|
||||||
d.AddNode("queue1", tasks.Node1)
|
d.AddNode("queue1", tasks.Node1)
|
||||||
d.AddNode("queue2", tasks.Node2)
|
d.AddNode("queue2", tasks.Node2)
|
||||||
d.AddNode("queue3", tasks.Node3)
|
d.AddNode("queue3", tasks.Node3)
|
||||||
@@ -24,45 +20,14 @@ func main() {
|
|||||||
d.AddLoop("queue2", "queue3")
|
d.AddLoop("queue2", "queue3")
|
||||||
d.AddEdge("queue2", "queue4")
|
d.AddEdge("queue2", "queue4")
|
||||||
d.Prepare()
|
d.Prepare()
|
||||||
http.HandleFunc("POST /publish", requestHandler("publish"))
|
go func() {
|
||||||
http.HandleFunc("POST /request", requestHandler("request"))
|
d.Start(context.Background(), ":8081")
|
||||||
err := d.Start(context.TODO(), ":8083")
|
}()
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
err := d.PublishTask(context.Background(), []byte(`{"tast": 123}`))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func requestHandler(requestType string) func(w http.ResponseWriter, r *http.Request) {
|
time.Sleep(10 * time.Second)
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Method != http.MethodPost {
|
|
||||||
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var payload []byte
|
|
||||||
if r.Body != nil {
|
|
||||||
defer r.Body.Close()
|
|
||||||
var err error
|
|
||||||
payload, err = io.ReadAll(r.Body)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Failed to read request body", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
http.Error(w, "Empty request body", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var rs mq.Result
|
|
||||||
if requestType == "request" {
|
|
||||||
rs = d.Request(context.Background(), payload)
|
|
||||||
} else {
|
|
||||||
rs = d.Send(context.Background(), payload)
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
result := map[string]any{
|
|
||||||
"message_id": rs.MessageID,
|
|
||||||
"payload": string(rs.Payload),
|
|
||||||
"error": rs.Error,
|
|
||||||
}
|
|
||||||
json.NewEncoder(w).Encode(result)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@@ -3,17 +3,16 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
mq2 "github.com/oarkflow/mq"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
mq "github.com/oarkflow/mq/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
payload := []byte(`{"message":"Message Publisher \n Task"}`)
|
payload := []byte(`{"message":"Message Publisher \n Task"}`)
|
||||||
task := mq.Task{
|
task := mq2.Task{
|
||||||
Payload: payload,
|
Payload: payload,
|
||||||
}
|
}
|
||||||
publisher := mq.NewPublisher("publish-1")
|
publisher := mq2.NewPublisher("publish-1")
|
||||||
// publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"))
|
// publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"))
|
||||||
err := publisher.Publish(context.Background(), task, "queue1")
|
err := publisher.Publish(context.Background(), task, "queue1")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -21,7 +20,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
fmt.Println("Async task published successfully")
|
fmt.Println("Async task published successfully")
|
||||||
payload = []byte(`{"message":"Fire-and-Forget \n Task"}`)
|
payload = []byte(`{"message":"Fire-and-Forget \n Task"}`)
|
||||||
task = mq.Task{
|
task = mq2.Task{
|
||||||
Payload: payload,
|
Payload: payload,
|
||||||
}
|
}
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
|
@@ -2,13 +2,13 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
mq2 "github.com/oarkflow/mq"
|
||||||
|
|
||||||
"github.com/oarkflow/mq/examples/tasks"
|
"github.com/oarkflow/mq/examples/tasks"
|
||||||
mq "github.com/oarkflow/mq/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
b := mq.NewBroker(mq.WithCallback(tasks.Callback))
|
b := mq2.NewBroker(mq2.WithCallback(tasks.Callback))
|
||||||
// b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert"))
|
// b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert"))
|
||||||
b.NewQueue("queue1")
|
b.NewQueue("queue1")
|
||||||
b.NewQueue("queue2")
|
b.NewQueue("queue2")
|
||||||
|
@@ -4,42 +4,41 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
mq2 "github.com/oarkflow/mq"
|
||||||
mq "github.com/oarkflow/mq/v2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Node1(ctx context.Context, task mq.Task) mq.Result {
|
func Node1(ctx context.Context, task mq2.Task) mq2.Result {
|
||||||
fmt.Println("Processing queue1", task.ID)
|
fmt.Println("Processing queue1", task.ID)
|
||||||
return mq.Result{Payload: task.Payload, MessageID: task.ID}
|
return mq2.Result{Payload: task.Payload, MessageID: task.ID}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Node2(ctx context.Context, task mq.Task) mq.Result {
|
func Node2(ctx context.Context, task mq2.Task) mq2.Result {
|
||||||
return mq.Result{Payload: task.Payload, MessageID: task.ID}
|
return mq2.Result{Payload: task.Payload, MessageID: task.ID}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Node3(ctx context.Context, task mq.Task) mq.Result {
|
func Node3(ctx context.Context, task mq2.Task) mq2.Result {
|
||||||
var data map[string]any
|
var data map[string]any
|
||||||
err := json.Unmarshal(task.Payload, &data)
|
err := json.Unmarshal(task.Payload, &data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return mq.Result{Error: err}
|
return mq2.Result{Error: err}
|
||||||
}
|
}
|
||||||
data["salary"] = fmt.Sprintf("12000%v", data["user_id"])
|
data["salary"] = fmt.Sprintf("12000%v", data["user_id"])
|
||||||
bt, _ := json.Marshal(data)
|
bt, _ := json.Marshal(data)
|
||||||
return mq.Result{Payload: bt, MessageID: task.ID}
|
return mq2.Result{Payload: bt, MessageID: task.ID}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Node4(ctx context.Context, task mq.Task) mq.Result {
|
func Node4(ctx context.Context, task mq2.Task) mq2.Result {
|
||||||
var data []map[string]any
|
var data []map[string]any
|
||||||
err := json.Unmarshal(task.Payload, &data)
|
err := json.Unmarshal(task.Payload, &data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return mq.Result{Error: err}
|
return mq2.Result{Error: err}
|
||||||
}
|
}
|
||||||
payload := map[string]any{"storage": data}
|
payload := map[string]any{"storage": data}
|
||||||
bt, _ := json.Marshal(payload)
|
bt, _ := json.Marshal(payload)
|
||||||
return mq.Result{Payload: bt, MessageID: task.ID}
|
return mq2.Result{Payload: bt, MessageID: task.ID}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Callback(ctx context.Context, task mq.Result) mq.Result {
|
func Callback(ctx context.Context, task mq2.Result) mq2.Result {
|
||||||
fmt.Println("Received task", task.MessageID, "Payload", string(task.Payload), task.Error, task.Queue)
|
fmt.Println("Received task", task.MessageID, "Payload", string(task.Payload), task.Error, task.Queue)
|
||||||
return mq.Result{}
|
return mq2.Result{}
|
||||||
}
|
}
|
||||||
|
85
options.go
85
options.go
@@ -2,9 +2,18 @@ package mq
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
Payload json.RawMessage `json:"payload"`
|
||||||
|
Queue string `json:"queue"`
|
||||||
|
MessageID string `json:"message_id"`
|
||||||
|
Error error `json:"error,omitempty"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
type TLSConfig struct {
|
type TLSConfig struct {
|
||||||
UseTLS bool
|
UseTLS bool
|
||||||
CertPath string
|
CertPath string
|
||||||
@@ -13,17 +22,18 @@ type TLSConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
syncMode bool
|
syncMode bool
|
||||||
brokerAddr string
|
brokerAddr string
|
||||||
messageHandler MessageHandler
|
callback []func(context.Context, Result) Result
|
||||||
closeHandler CloseHandler
|
maxRetries int
|
||||||
errorHandler ErrorHandler
|
initialDelay time.Duration
|
||||||
callback []func(context.Context, Result) Result
|
maxBackoff time.Duration
|
||||||
maxRetries int
|
jitterPercent float64
|
||||||
initialDelay time.Duration
|
tlsConfig TLSConfig
|
||||||
maxBackoff time.Duration
|
aesKey json.RawMessage
|
||||||
jitterPercent float64
|
hmacKey json.RawMessage
|
||||||
tlsConfig TLSConfig
|
enableEncryption bool
|
||||||
|
queueSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultOptions() Options {
|
func defaultOptions() Options {
|
||||||
@@ -34,27 +44,29 @@ func defaultOptions() Options {
|
|||||||
initialDelay: 2 * time.Second,
|
initialDelay: 2 * time.Second,
|
||||||
maxBackoff: 20 * time.Second,
|
maxBackoff: 20 * time.Second,
|
||||||
jitterPercent: 0.5,
|
jitterPercent: 0.5,
|
||||||
|
queueSize: 100,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultHandlers(options Options, onMessage MessageHandler, onClose CloseHandler, onError ErrorHandler) Options {
|
|
||||||
if options.messageHandler == nil {
|
|
||||||
options.messageHandler = onMessage
|
|
||||||
}
|
|
||||||
|
|
||||||
if options.closeHandler == nil {
|
|
||||||
options.closeHandler = onClose
|
|
||||||
}
|
|
||||||
|
|
||||||
if options.errorHandler == nil {
|
|
||||||
options.errorHandler = onError
|
|
||||||
}
|
|
||||||
return options
|
|
||||||
}
|
|
||||||
|
|
||||||
// Option defines a function type for setting options.
|
// Option defines a function type for setting options.
|
||||||
type Option func(*Options)
|
type Option func(*Options)
|
||||||
|
|
||||||
|
func setupOptions(opts ...Option) Options {
|
||||||
|
options := defaultOptions()
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(&options)
|
||||||
|
}
|
||||||
|
return options
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option {
|
||||||
|
return func(opts *Options) {
|
||||||
|
opts.aesKey = aesKey
|
||||||
|
opts.hmacKey = hmacKey
|
||||||
|
opts.enableEncryption = enableEncryption
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithBrokerURL -
|
// WithBrokerURL -
|
||||||
func WithBrokerURL(url string) Option {
|
func WithBrokerURL(url string) Option {
|
||||||
return func(opts *Options) {
|
return func(opts *Options) {
|
||||||
@@ -119,24 +131,3 @@ func WithJitterPercent(val float64) Option {
|
|||||||
opts.jitterPercent = val
|
opts.jitterPercent = val
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithMessageHandler sets a custom MessageHandler.
|
|
||||||
func WithMessageHandler(handler MessageHandler) Option {
|
|
||||||
return func(opts *Options) {
|
|
||||||
opts.messageHandler = handler
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithErrorHandler sets a custom ErrorHandler.
|
|
||||||
func WithErrorHandler(handler ErrorHandler) Option {
|
|
||||||
return func(opts *Options) {
|
|
||||||
opts.errorHandler = handler
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithCloseHandler sets a custom CloseHandler.
|
|
||||||
func WithCloseHandler(handler CloseHandler) Option {
|
|
||||||
return func(opts *Options) {
|
|
||||||
opts.closeHandler = handler
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
100
publisher.go
100
publisher.go
@@ -4,9 +4,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/oarkflow/mq/codec"
|
||||||
"github.com/oarkflow/mq/consts"
|
"github.com/oarkflow/mq/consts"
|
||||||
|
"github.com/oarkflow/mq/jsonparser"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Publisher struct {
|
type Publisher struct {
|
||||||
@@ -15,31 +19,59 @@ type Publisher struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewPublisher(id string, opts ...Option) *Publisher {
|
func NewPublisher(id string, opts ...Option) *Publisher {
|
||||||
options := defaultOptions()
|
options := setupOptions(opts...)
|
||||||
for _, opt := range opts {
|
return &Publisher{id: id, opts: options}
|
||||||
opt(&options)
|
|
||||||
}
|
|
||||||
b := &Publisher{id: id}
|
|
||||||
b.opts = defaultHandlers(options, nil, b.onClose, b.onError)
|
|
||||||
return b
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error {
|
func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error {
|
||||||
ctx = SetHeaders(ctx, map[string]string{
|
headers := WithHeaders(ctx, map[string]string{
|
||||||
consts.PublisherKey: p.id,
|
consts.PublisherKey: p.id,
|
||||||
consts.ContentType: consts.TypeJson,
|
consts.ContentType: consts.TypeJson,
|
||||||
})
|
})
|
||||||
cmd := Command{
|
if task.ID == "" {
|
||||||
ID: NewID(),
|
task.ID = NewID()
|
||||||
Command: command,
|
|
||||||
Queue: queue,
|
|
||||||
MessageID: task.ID,
|
|
||||||
Payload: task.Payload,
|
|
||||||
}
|
}
|
||||||
return Write(ctx, conn, cmd)
|
task.CreatedAt = time.Now()
|
||||||
|
payload, err := json.Marshal(task)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
msg := codec.NewMessage(command, payload, queue, headers)
|
||||||
|
if err := codec.SendMessage(conn, msg, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.waitForAck(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Publisher) Publish(ctx context.Context, queue string, task Task) error {
|
func (p *Publisher) waitForAck(conn net.Conn) error {
|
||||||
|
msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if msg.Command == consts.PUBLISH_ACK {
|
||||||
|
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||||
|
log.Printf("PUBLISHER - PUBLISH_ACK ~> from %s on %s for Task %s", p.id, msg.Queue, taskID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) waitForResponse(conn net.Conn) Result {
|
||||||
|
msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption)
|
||||||
|
if err != nil {
|
||||||
|
return Result{Error: err}
|
||||||
|
}
|
||||||
|
if msg.Command == consts.RESPONSE {
|
||||||
|
var result Result
|
||||||
|
err = json.Unmarshal(msg.Payload, &result)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
err = fmt.Errorf("expected RESPONSE, got: %v", msg.Command)
|
||||||
|
return Result{Error: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error {
|
||||||
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to connect to broker: %w", err)
|
return fmt.Errorf("failed to connect to broker: %w", err)
|
||||||
@@ -57,30 +89,22 @@ func (p *Publisher) onError(ctx context.Context, conn net.Conn, err error) {
|
|||||||
fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr())
|
fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Publisher) Request(ctx context.Context, queue string, task Task) (Result, error) {
|
func (p *Publisher) Request(ctx context.Context, queue string, task Task) Result {
|
||||||
|
ctx = SetHeaders(ctx, map[string]string{
|
||||||
|
consts.AwaitResponseKey: "true",
|
||||||
|
})
|
||||||
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Result{Error: err}, fmt.Errorf("failed to connect to broker: %w", err)
|
err = fmt.Errorf("failed to connect to broker: %w", err)
|
||||||
|
return Result{Error: err}
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
var result Result
|
err = p.send(ctx, queue, task, conn, consts.PUBLISH)
|
||||||
err = p.send(ctx, queue, task, conn, consts.REQUEST)
|
resultCh := make(chan Result)
|
||||||
if err != nil {
|
go func() {
|
||||||
return result, err
|
defer close(resultCh)
|
||||||
}
|
resultCh <- p.waitForResponse(conn)
|
||||||
if p.opts.messageHandler == nil {
|
}()
|
||||||
p.opts.messageHandler = func(ctx context.Context, conn net.Conn, message []byte) error {
|
finalResult := <-resultCh
|
||||||
err := json.Unmarshal(message, &result)
|
return finalResult
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return conn.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ReadFromConn(ctx, conn, Handlers{
|
|
||||||
MessageHandler: p.opts.messageHandler,
|
|
||||||
CloseHandler: p.opts.closeHandler,
|
|
||||||
ErrorHandler: p.opts.errorHandler,
|
|
||||||
})
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
package v2
|
package mq
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/oarkflow/xsync"
|
"github.com/oarkflow/xsync"
|
@@ -1,4 +1,4 @@
|
|||||||
package v2
|
package mq
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
323
v2/broker.go
323
v2/broker.go
@@ -1,323 +0,0 @@
|
|||||||
package v2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/oarkflow/xsync"
|
|
||||||
|
|
||||||
"github.com/oarkflow/mq/codec"
|
|
||||||
"github.com/oarkflow/mq/consts"
|
|
||||||
"github.com/oarkflow/mq/jsonparser"
|
|
||||||
"github.com/oarkflow/mq/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
type QueuedTask struct {
|
|
||||||
Message *codec.Message
|
|
||||||
RetryCount int
|
|
||||||
}
|
|
||||||
|
|
||||||
type consumer struct {
|
|
||||||
id string
|
|
||||||
conn net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
type publisher struct {
|
|
||||||
id string
|
|
||||||
conn net.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
type Broker struct {
|
|
||||||
queues xsync.IMap[string, *Queue]
|
|
||||||
consumers xsync.IMap[string, *consumer]
|
|
||||||
publishers xsync.IMap[string, *publisher]
|
|
||||||
opts Options
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewBroker(opts ...Option) *Broker {
|
|
||||||
options := setupOptions(opts...)
|
|
||||||
return &Broker{
|
|
||||||
queues: xsync.NewMap[string, *Queue](),
|
|
||||||
publishers: xsync.NewMap[string, *publisher](),
|
|
||||||
consumers: xsync.NewMap[string, *consumer](),
|
|
||||||
opts: options,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) OnClose(ctx context.Context, _ net.Conn) error {
|
|
||||||
consumerID, ok := GetConsumerID(ctx)
|
|
||||||
if ok && consumerID != "" {
|
|
||||||
if con, exists := b.consumers.Get(consumerID); exists {
|
|
||||||
con.conn.Close()
|
|
||||||
b.consumers.Del(consumerID)
|
|
||||||
}
|
|
||||||
b.queues.ForEach(func(_ string, queue *Queue) bool {
|
|
||||||
queue.consumers.Del(consumerID)
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
}
|
|
||||||
publisherID, ok := GetPublisherID(ctx)
|
|
||||||
if ok && publisherID != "" {
|
|
||||||
if con, exists := b.publishers.Get(publisherID); exists {
|
|
||||||
con.conn.Close()
|
|
||||||
b.publishers.Del(publisherID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) {
|
|
||||||
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
|
||||||
switch msg.Command {
|
|
||||||
case consts.PUBLISH:
|
|
||||||
b.PublishHandler(ctx, conn, msg)
|
|
||||||
case consts.SUBSCRIBE:
|
|
||||||
b.SubscribeHandler(ctx, conn, msg)
|
|
||||||
case consts.MESSAGE_RESPONSE:
|
|
||||||
b.MessageResponseHandler(ctx, msg)
|
|
||||||
case consts.MESSAGE_ACK:
|
|
||||||
b.MessageAck(ctx, msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) {
|
|
||||||
consumerID, _ := GetConsumerID(ctx)
|
|
||||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
|
||||||
log.Printf("BROKER - MESSAGE_ACK ~> %s on %s for Task %s", consumerID, msg.Queue, taskID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) {
|
|
||||||
msg.Command = consts.RESPONSE
|
|
||||||
headers, ok := GetHeaders(ctx)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
b.HandleCallback(ctx, msg)
|
|
||||||
awaitResponse, ok := headers[consts.AwaitResponseKey]
|
|
||||||
if !(ok && awaitResponse == "true") {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
publisherID, exists := headers[consts.PublisherKey]
|
|
||||||
if !exists {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
con, ok := b.publishers.Get(publisherID)
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err := b.send(con.conn, msg)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) Publish(ctx context.Context, task Task, queue string) error {
|
|
||||||
headers, _ := GetHeaders(ctx)
|
|
||||||
msg := codec.NewMessage(consts.PUBLISH, task.Payload, queue, headers)
|
|
||||||
b.broadcastToConsumers(ctx, msg)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
|
||||||
pub := b.addPublisher(ctx, msg.Queue, conn)
|
|
||||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
|
||||||
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
|
|
||||||
|
|
||||||
ack := codec.NewMessage(consts.PUBLISH_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
|
|
||||||
if err := b.send(conn, ack); err != nil {
|
|
||||||
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
|
|
||||||
}
|
|
||||||
b.broadcastToConsumers(ctx, msg)
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
b.publishers.Del(pub.id)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
|
||||||
consumerID := b.addConsumer(ctx, msg.Queue, conn)
|
|
||||||
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
|
|
||||||
if err := b.send(conn, ack); err != nil {
|
|
||||||
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
b.removeConsumer(msg.Queue, consumerID)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) Start(ctx context.Context) error {
|
|
||||||
var listener net.Listener
|
|
||||||
var err error
|
|
||||||
|
|
||||||
if b.opts.tlsConfig.UseTLS {
|
|
||||||
cert, err := tls.LoadX509KeyPair(b.opts.tlsConfig.CertPath, b.opts.tlsConfig.KeyPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to load TLS certificates: %v", err)
|
|
||||||
}
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
Certificates: []tls.Certificate{cert},
|
|
||||||
}
|
|
||||||
listener, err = tls.Listen("tcp", b.opts.brokerAddr, tlsConfig)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to start TLS listener: %v", err)
|
|
||||||
}
|
|
||||||
log.Println("BROKER - RUNNING_TLS ~> started on", b.opts.brokerAddr)
|
|
||||||
} else {
|
|
||||||
listener, err = net.Listen("tcp", b.opts.brokerAddr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to start TCP listener: %v", err)
|
|
||||||
}
|
|
||||||
log.Println("BROKER - RUNNING ~> started on", b.opts.brokerAddr)
|
|
||||||
}
|
|
||||||
defer listener.Close()
|
|
||||||
for {
|
|
||||||
conn, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
b.OnError(ctx, conn, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go func(c net.Conn) {
|
|
||||||
defer c.Close()
|
|
||||||
for {
|
|
||||||
err := b.readMessage(ctx, c)
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}(conn)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) send(conn net.Conn, msg *codec.Message) error {
|
|
||||||
return codec.SendMessage(conn, msg, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) receive(c net.Conn) (*codec.Message, error) {
|
|
||||||
return codec.ReadMessage(c, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) {
|
|
||||||
if queue, ok := b.queues.Get(msg.Queue); ok {
|
|
||||||
task := &QueuedTask{Message: msg, RetryCount: 0}
|
|
||||||
queue.tasks <- task
|
|
||||||
log.Printf("Task enqueued for queue %s", msg.Queue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) waitForConsumerAck(conn net.Conn) error {
|
|
||||||
msg, err := b.receive(conn)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if msg.Command == consts.MESSAGE_ACK {
|
|
||||||
log.Println("Received CONSUMER_ACK: Subscribed successfully")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("expected CONSUMER_ACK, got: %v", msg.Command)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) *publisher {
|
|
||||||
publisherID, ok := GetPublisherID(ctx)
|
|
||||||
_, ok = b.queues.Get(queueName)
|
|
||||||
if !ok {
|
|
||||||
b.NewQueue(queueName)
|
|
||||||
}
|
|
||||||
con := &publisher{id: publisherID, conn: conn}
|
|
||||||
b.publishers.Set(publisherID, con)
|
|
||||||
return con
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string {
|
|
||||||
consumerID, ok := GetConsumerID(ctx)
|
|
||||||
q, ok := b.queues.Get(queueName)
|
|
||||||
if !ok {
|
|
||||||
q = b.NewQueue(queueName)
|
|
||||||
}
|
|
||||||
con := &consumer{id: consumerID, conn: conn}
|
|
||||||
b.consumers.Set(consumerID, con)
|
|
||||||
q.consumers.Set(consumerID, con)
|
|
||||||
log.Printf("BROKER - SUBSCRIBE ~> %s on %s", consumerID, queueName)
|
|
||||||
return consumerID
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) removeConsumer(queueName, consumerID string) {
|
|
||||||
if queue, ok := b.queues.Get(queueName); ok {
|
|
||||||
con, ok := queue.consumers.Get(consumerID)
|
|
||||||
if ok {
|
|
||||||
con.conn.Close()
|
|
||||||
queue.consumers.Del(consumerID)
|
|
||||||
}
|
|
||||||
b.queues.Del(queueName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
|
|
||||||
msg, err := b.receive(c)
|
|
||||||
if err == nil {
|
|
||||||
ctx = SetHeaders(ctx, msg.Headers)
|
|
||||||
b.OnMessage(ctx, msg, c)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
|
|
||||||
b.OnClose(ctx, c)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
b.OnError(ctx, c, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) dispatchWorker(queue *Queue) {
|
|
||||||
delay := b.opts.initialDelay
|
|
||||||
for task := range queue.tasks {
|
|
||||||
success := false
|
|
||||||
for !success && task.RetryCount <= b.opts.maxRetries {
|
|
||||||
if b.dispatchTaskToConsumer(queue, task) {
|
|
||||||
success = true
|
|
||||||
} else {
|
|
||||||
task.RetryCount++
|
|
||||||
delay = b.backoffRetry(queue, task, delay)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool {
|
|
||||||
var consumerFound bool
|
|
||||||
queue.consumers.ForEach(func(_ string, con *consumer) bool {
|
|
||||||
if err := b.send(con.conn, task.Message); err == nil {
|
|
||||||
consumerFound = true
|
|
||||||
log.Printf("Task dispatched to consumer %s on queue %s", con.id, queue.name)
|
|
||||||
return false // break the loop once a consumer is found
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
if !consumerFound {
|
|
||||||
log.Printf("No available consumers for queue %s, retrying...", queue.name)
|
|
||||||
}
|
|
||||||
return consumerFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) backoffRetry(queue *Queue, task *QueuedTask, delay time.Duration) time.Duration {
|
|
||||||
backoffDuration := utils.CalculateJitter(delay, b.opts.jitterPercent)
|
|
||||||
log.Printf("Backing off for %v before retrying task for queue %s", backoffDuration, task.Message.Queue)
|
|
||||||
time.Sleep(backoffDuration)
|
|
||||||
queue.tasks <- task
|
|
||||||
delay *= 2
|
|
||||||
if delay > b.opts.maxBackoff {
|
|
||||||
delay = b.opts.maxBackoff
|
|
||||||
}
|
|
||||||
return delay
|
|
||||||
}
|
|
197
v2/consumer.go
197
v2/consumer.go
@@ -1,197 +0,0 @@
|
|||||||
package v2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/oarkflow/mq/codec"
|
|
||||||
"github.com/oarkflow/mq/consts"
|
|
||||||
"github.com/oarkflow/mq/jsonparser"
|
|
||||||
"github.com/oarkflow/mq/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Consumer structure to hold consumer-specific configurations and state.
|
|
||||||
type Consumer struct {
|
|
||||||
id string
|
|
||||||
handlers map[string]Handler
|
|
||||||
conn net.Conn
|
|
||||||
queues []string
|
|
||||||
opts Options
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewConsumer initializes a new consumer with the provided options.
|
|
||||||
func NewConsumer(id string, opts ...Option) *Consumer {
|
|
||||||
options := setupOptions(opts...)
|
|
||||||
return &Consumer{
|
|
||||||
handlers: make(map[string]Handler),
|
|
||||||
id: id,
|
|
||||||
opts: options,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) send(conn net.Conn, msg *codec.Message) error {
|
|
||||||
return codec.SendMessage(conn, msg, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) {
|
|
||||||
return codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the consumer's connection.
|
|
||||||
func (c *Consumer) Close() error {
|
|
||||||
return c.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Subscribe to a specific queue.
|
|
||||||
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
|
|
||||||
headers := WithHeaders(ctx, map[string]string{
|
|
||||||
consts.ConsumerKey: c.id,
|
|
||||||
consts.ContentType: consts.TypeJson,
|
|
||||||
})
|
|
||||||
msg := codec.NewMessage(consts.SUBSCRIBE, nil, queue, headers)
|
|
||||||
if err := c.send(c.conn, msg); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.waitForAck(c.conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) OnClose(ctx context.Context, _ net.Conn) error {
|
|
||||||
fmt.Println("Consumer closed")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) {
|
|
||||||
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
|
||||||
headers := WithHeaders(ctx, map[string]string{
|
|
||||||
consts.ConsumerKey: c.id,
|
|
||||||
consts.ContentType: consts.TypeJson,
|
|
||||||
})
|
|
||||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
|
||||||
reply := codec.NewMessage(consts.MESSAGE_ACK, []byte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
|
|
||||||
if err := c.send(conn, reply); err != nil {
|
|
||||||
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
|
|
||||||
}
|
|
||||||
var task Task
|
|
||||||
err := json.Unmarshal(msg.Payload, &task)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("Error unmarshalling message:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue})
|
|
||||||
result := c.ProcessTask(ctx, task)
|
|
||||||
result.MessageID = task.ID
|
|
||||||
result.Queue = msg.Queue
|
|
||||||
if result.Error != nil {
|
|
||||||
result.Status = "FAILED"
|
|
||||||
} else {
|
|
||||||
result.Status = "SUCCESS"
|
|
||||||
}
|
|
||||||
bt, _ := json.Marshal(result)
|
|
||||||
reply = codec.NewMessage(consts.MESSAGE_RESPONSE, bt, msg.Queue, headers)
|
|
||||||
if err := c.send(conn, reply); err != nil {
|
|
||||||
fmt.Printf("failed to send MESSAGE_RESPONSE for queue %s: %v", msg.Queue, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ProcessTask handles a received task message and invokes the appropriate handler.
|
|
||||||
func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result {
|
|
||||||
queue, _ := GetQueue(ctx)
|
|
||||||
handler, exists := c.handlers[queue]
|
|
||||||
if !exists {
|
|
||||||
return Result{Error: errors.New("No handler for queue " + queue)}
|
|
||||||
}
|
|
||||||
return handler(ctx, msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AttemptConnect tries to establish a connection to the server, with TLS or without, based on the configuration.
|
|
||||||
func (c *Consumer) AttemptConnect() error {
|
|
||||||
var err error
|
|
||||||
delay := c.opts.initialDelay
|
|
||||||
for i := 0; i < c.opts.maxRetries; i++ {
|
|
||||||
conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig)
|
|
||||||
if err == nil {
|
|
||||||
c.conn = conn
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent)
|
|
||||||
log.Printf("CONSUMER - SUBSCRIBE ~> Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
|
|
||||||
time.Sleep(sleepDuration)
|
|
||||||
delay *= 2
|
|
||||||
if delay > c.opts.maxBackoff {
|
|
||||||
delay = c.opts.maxBackoff
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
|
|
||||||
msg, err := c.receive(conn)
|
|
||||||
if err == nil {
|
|
||||||
ctx = SetHeaders(ctx, msg.Headers)
|
|
||||||
c.OnMessage(ctx, msg, conn)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
|
|
||||||
c.OnClose(ctx, conn)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.OnError(ctx, conn, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Consume starts the consumer to consume tasks from the queues.
|
|
||||||
func (c *Consumer) Consume(ctx context.Context) error {
|
|
||||||
err := c.AttemptConnect()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for _, q := range c.queues {
|
|
||||||
if err := c.subscribe(ctx, q); err != nil {
|
|
||||||
return fmt.Errorf("failed to connect to server for queue %s: %v", q, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
for {
|
|
||||||
if err := c.readMessage(ctx, c.conn); err != nil {
|
|
||||||
log.Println("Error reading message:", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
wg.Wait()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Consumer) waitForAck(conn net.Conn) error {
|
|
||||||
msg, err := c.receive(conn)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if msg.Command == consts.SUBSCRIBE_ACK {
|
|
||||||
log.Printf("CONSUMER - SUBSCRIBE_ACK ~> %s on %s", c.id, msg.Queue)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RegisterHandler registers a handler for a queue.
|
|
||||||
func (c *Consumer) RegisterHandler(queue string, handler Handler) {
|
|
||||||
c.queues = append(c.queues, queue)
|
|
||||||
c.handlers[queue] = handler
|
|
||||||
}
|
|
158
v2/ctx.go
158
v2/ctx.go
@@ -1,158 +0,0 @@
|
|||||||
package v2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/oarkflow/xid"
|
|
||||||
|
|
||||||
"github.com/oarkflow/mq/consts"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Task struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Payload json.RawMessage `json:"payload"`
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
ProcessedAt time.Time `json:"processed_at"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
Error error `json:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Handler func(context.Context, Task) Result
|
|
||||||
|
|
||||||
func IsClosed(conn net.Conn) bool {
|
|
||||||
_, err := conn.Read(make([]byte, 1))
|
|
||||||
if err != nil {
|
|
||||||
if err == net.ErrClosed {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func SetHeaders(ctx context.Context, headers map[string]string) context.Context {
|
|
||||||
hd, ok := GetHeaders(ctx)
|
|
||||||
if !ok {
|
|
||||||
hd = make(map[string]string)
|
|
||||||
}
|
|
||||||
for key, val := range headers {
|
|
||||||
hd[key] = val
|
|
||||||
}
|
|
||||||
return context.WithValue(ctx, consts.HeaderKey, hd)
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithHeaders(ctx context.Context, headers map[string]string) map[string]string {
|
|
||||||
hd, ok := GetHeaders(ctx)
|
|
||||||
if !ok {
|
|
||||||
hd = make(map[string]string)
|
|
||||||
}
|
|
||||||
for key, val := range headers {
|
|
||||||
hd[key] = val
|
|
||||||
}
|
|
||||||
return hd
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetHeaders(ctx context.Context) (map[string]string, bool) {
|
|
||||||
headers, ok := ctx.Value(consts.HeaderKey).(map[string]string)
|
|
||||||
return headers, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetHeader(ctx context.Context, key string) (string, bool) {
|
|
||||||
headers, ok := ctx.Value(consts.HeaderKey).(map[string]string)
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
val, ok := headers[key]
|
|
||||||
return val, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetContentType(ctx context.Context) (string, bool) {
|
|
||||||
headers, ok := GetHeaders(ctx)
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
contentType, ok := headers[consts.ContentType]
|
|
||||||
return contentType, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetQueue(ctx context.Context) (string, bool) {
|
|
||||||
headers, ok := GetHeaders(ctx)
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
contentType, ok := headers[consts.QueueKey]
|
|
||||||
return contentType, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetConsumerID(ctx context.Context) (string, bool) {
|
|
||||||
headers, ok := GetHeaders(ctx)
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
contentType, ok := headers[consts.ConsumerKey]
|
|
||||||
return contentType, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetTriggerNode(ctx context.Context) (string, bool) {
|
|
||||||
headers, ok := GetHeaders(ctx)
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
contentType, ok := headers[consts.TriggerNode]
|
|
||||||
return contentType, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetPublisherID(ctx context.Context) (string, bool) {
|
|
||||||
headers, ok := GetHeaders(ctx)
|
|
||||||
if !ok {
|
|
||||||
return "", false
|
|
||||||
}
|
|
||||||
contentType, ok := headers[consts.PublisherKey]
|
|
||||||
return contentType, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewID() string {
|
|
||||||
return xid.New().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func createTLSConnection(addr, certPath, keyPath string, caPath ...string) (net.Conn, error) {
|
|
||||||
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to load client cert/key: %w", err)
|
|
||||||
}
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
Certificates: []tls.Certificate{cert},
|
|
||||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
||||||
InsecureSkipVerify: true,
|
|
||||||
}
|
|
||||||
if len(caPath) > 0 && caPath[0] != "" {
|
|
||||||
caCert, err := os.ReadFile(caPath[0])
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to load CA cert: %w", err)
|
|
||||||
}
|
|
||||||
caCertPool := x509.NewCertPool()
|
|
||||||
caCertPool.AppendCertsFromPEM(caCert)
|
|
||||||
tlsConfig.RootCAs = caCertPool
|
|
||||||
tlsConfig.ClientCAs = caCertPool
|
|
||||||
}
|
|
||||||
conn, err := tls.Dial("tcp", addr, tlsConfig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to dial TLS connection: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetConnection(addr string, config TLSConfig) (net.Conn, error) {
|
|
||||||
if config.UseTLS {
|
|
||||||
return createTLSConnection(addr, config.CertPath, config.KeyPath, config.CAPath)
|
|
||||||
} else {
|
|
||||||
return net.Dial("tcp", addr)
|
|
||||||
}
|
|
||||||
}
|
|
133
v2/options.go
133
v2/options.go
@@ -1,133 +0,0 @@
|
|||||||
package v2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Result struct {
|
|
||||||
Payload json.RawMessage `json:"payload"`
|
|
||||||
Queue string `json:"queue"`
|
|
||||||
MessageID string `json:"message_id"`
|
|
||||||
Error error `json:"error,omitempty"`
|
|
||||||
Status string `json:"status"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TLSConfig struct {
|
|
||||||
UseTLS bool
|
|
||||||
CertPath string
|
|
||||||
KeyPath string
|
|
||||||
CAPath string
|
|
||||||
}
|
|
||||||
|
|
||||||
type Options struct {
|
|
||||||
syncMode bool
|
|
||||||
brokerAddr string
|
|
||||||
callback []func(context.Context, Result) Result
|
|
||||||
maxRetries int
|
|
||||||
initialDelay time.Duration
|
|
||||||
maxBackoff time.Duration
|
|
||||||
jitterPercent float64
|
|
||||||
tlsConfig TLSConfig
|
|
||||||
aesKey json.RawMessage
|
|
||||||
hmacKey json.RawMessage
|
|
||||||
enableEncryption bool
|
|
||||||
queueSize int
|
|
||||||
}
|
|
||||||
|
|
||||||
func defaultOptions() Options {
|
|
||||||
return Options{
|
|
||||||
syncMode: false,
|
|
||||||
brokerAddr: ":8080",
|
|
||||||
maxRetries: 5,
|
|
||||||
initialDelay: 2 * time.Second,
|
|
||||||
maxBackoff: 20 * time.Second,
|
|
||||||
jitterPercent: 0.5,
|
|
||||||
queueSize: 100,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Option defines a function type for setting options.
|
|
||||||
type Option func(*Options)
|
|
||||||
|
|
||||||
func setupOptions(opts ...Option) Options {
|
|
||||||
options := defaultOptions()
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(&options)
|
|
||||||
}
|
|
||||||
return options
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option {
|
|
||||||
return func(opts *Options) {
|
|
||||||
opts.aesKey = aesKey
|
|
||||||
opts.hmacKey = hmacKey
|
|
||||||
opts.enableEncryption = enableEncryption
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithBrokerURL -
|
|
||||||
func WithBrokerURL(url string) Option {
|
|
||||||
return func(opts *Options) {
|
|
||||||
opts.brokerAddr = url
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithTLS - Option to enable/disable TLS
|
|
||||||
func WithTLS(enableTLS bool, certPath, keyPath string) Option {
|
|
||||||
return func(o *Options) {
|
|
||||||
o.tlsConfig.UseTLS = enableTLS
|
|
||||||
o.tlsConfig.CertPath = certPath
|
|
||||||
o.tlsConfig.KeyPath = keyPath
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithCAPath - Option to enable/disable TLS
|
|
||||||
func WithCAPath(caPath string) Option {
|
|
||||||
return func(o *Options) {
|
|
||||||
o.tlsConfig.CAPath = caPath
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithSyncMode -
|
|
||||||
func WithSyncMode(mode bool) Option {
|
|
||||||
return func(opts *Options) {
|
|
||||||
opts.syncMode = mode
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithMaxRetries -
|
|
||||||
func WithMaxRetries(val int) Option {
|
|
||||||
return func(opts *Options) {
|
|
||||||
opts.maxRetries = val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithInitialDelay -
|
|
||||||
func WithInitialDelay(val time.Duration) Option {
|
|
||||||
return func(opts *Options) {
|
|
||||||
opts.initialDelay = val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithMaxBackoff -
|
|
||||||
func WithMaxBackoff(val time.Duration) Option {
|
|
||||||
return func(opts *Options) {
|
|
||||||
opts.maxBackoff = val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithCallback -
|
|
||||||
func WithCallback(val ...func(context.Context, Result) Result) Option {
|
|
||||||
return func(opts *Options) {
|
|
||||||
opts.callback = val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithJitterPercent -
|
|
||||||
func WithJitterPercent(val float64) Option {
|
|
||||||
return func(opts *Options) {
|
|
||||||
opts.jitterPercent = val
|
|
||||||
}
|
|
||||||
}
|
|
110
v2/publisher.go
110
v2/publisher.go
@@ -1,110 +0,0 @@
|
|||||||
package v2
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/oarkflow/mq/codec"
|
|
||||||
"github.com/oarkflow/mq/consts"
|
|
||||||
"github.com/oarkflow/mq/jsonparser"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Publisher struct {
|
|
||||||
id string
|
|
||||||
opts Options
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPublisher(id string, opts ...Option) *Publisher {
|
|
||||||
options := setupOptions(opts...)
|
|
||||||
return &Publisher{id: id, opts: options}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error {
|
|
||||||
headers := WithHeaders(ctx, map[string]string{
|
|
||||||
consts.PublisherKey: p.id,
|
|
||||||
consts.ContentType: consts.TypeJson,
|
|
||||||
})
|
|
||||||
if task.ID == "" {
|
|
||||||
task.ID = NewID()
|
|
||||||
}
|
|
||||||
task.CreatedAt = time.Now()
|
|
||||||
payload, err := json.Marshal(task)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
msg := codec.NewMessage(command, payload, queue, headers)
|
|
||||||
if err := codec.SendMessage(conn, msg, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.waitForAck(conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Publisher) waitForAck(conn net.Conn) error {
|
|
||||||
msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if msg.Command == consts.PUBLISH_ACK {
|
|
||||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
|
||||||
log.Printf("PUBLISHER - PUBLISH_ACK ~> from %s on %s for Task %s", p.id, msg.Queue, taskID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Publisher) waitForResponse(conn net.Conn) Result {
|
|
||||||
msg, err := codec.ReadMessage(conn, p.opts.aesKey, p.opts.hmacKey, p.opts.enableEncryption)
|
|
||||||
if err != nil {
|
|
||||||
return Result{Error: err}
|
|
||||||
}
|
|
||||||
if msg.Command == consts.RESPONSE {
|
|
||||||
var result Result
|
|
||||||
err = json.Unmarshal(msg.Payload, &result)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
err = fmt.Errorf("expected RESPONSE, got: %v", msg.Command)
|
|
||||||
return Result{Error: err}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error {
|
|
||||||
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to connect to broker: %w", err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
return p.send(ctx, queue, task, conn, consts.PUBLISH)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Publisher) onClose(ctx context.Context, conn net.Conn) error {
|
|
||||||
fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr())
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Publisher) onError(ctx context.Context, conn net.Conn, err error) {
|
|
||||||
fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Publisher) Request(ctx context.Context, queue string, task Task) Result {
|
|
||||||
ctx = SetHeaders(ctx, map[string]string{
|
|
||||||
consts.AwaitResponseKey: "true",
|
|
||||||
})
|
|
||||||
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
|
||||||
if err != nil {
|
|
||||||
err = fmt.Errorf("failed to connect to broker: %w", err)
|
|
||||||
return Result{Error: err}
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
err = p.send(ctx, queue, task, conn, consts.PUBLISH)
|
|
||||||
resultCh := make(chan Result)
|
|
||||||
go func() {
|
|
||||||
defer close(resultCh)
|
|
||||||
resultCh <- p.waitForResponse(conn)
|
|
||||||
}()
|
|
||||||
finalResult := <-resultCh
|
|
||||||
return finalResult
|
|
||||||
}
|
|
Reference in New Issue
Block a user