mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-05 07:57:00 +08:00
1322 lines
34 KiB
Go
1322 lines
34 KiB
Go
package mq
|
||
|
||
import (
|
||
"context"
|
||
"crypto/tls"
|
||
"fmt"
|
||
"log"
|
||
"net"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/oarkflow/errors"
|
||
"github.com/oarkflow/json"
|
||
|
||
"github.com/oarkflow/mq/codec"
|
||
"github.com/oarkflow/mq/consts"
|
||
"github.com/oarkflow/mq/jsonparser"
|
||
"github.com/oarkflow/mq/logger"
|
||
"github.com/oarkflow/mq/storage"
|
||
"github.com/oarkflow/mq/storage/memory"
|
||
"github.com/oarkflow/mq/utils"
|
||
)
|
||
|
||
// Status describes the lifecycle stage of a task.
type Status string

// Task lifecycle states.
const (
	Pending    Status = "Pending"
	Processing Status = "Processing"
	Completed  Status = "Completed"
	Failed     Status = "Failed"
)

// Result is the outcome of processing a task. Error is excluded from the
// default JSON encoding; MarshalJSON and UnmarshalJSON carry it as an "error"
// string instead.
type Result struct {
	CreatedAt       time.Time       `json:"created_at"`
	ProcessedAt     time.Time       `json:"processed_at,omitempty"`
	Latency         string          `json:"latency"`
	Error           error           `json:"-"` // kept as an error value; encoded under the "error" key
	Topic           string          `json:"topic"`
	TaskID          string          `json:"task_id"`
	Status          Status          `json:"status"`
	ConditionStatus string          `json:"condition_status"`
	Ctx             context.Context `json:"-"`
	Payload         json.RawMessage `json:"payload"`
	Last            bool
}

// MarshalJSON encodes r and, when r.Error is non-nil, adds an "error" key
// holding the error's message.
func (r Result) MarshalJSON() ([]byte, error) {
	// Alias has Result's fields but none of its methods, so encoding it does
	// not recurse back into MarshalJSON.
	type Alias Result
	var errText string
	if r.Error != nil {
		errText = r.Error.Error()
	}
	return json.Marshal(&struct {
		ErrorMsg string `json:"error,omitempty"`
		Alias
	}{
		ErrorMsg: errText,
		Alias:    Alias(r),
	})
}

// UnmarshalJSON decodes data into r and rebuilds r.Error from the optional
// "error" key; a missing or empty "error" leaves r.Error nil. On a decode
// error r.Error is left untouched.
func (r *Result) UnmarshalJSON(data []byte) error {
	type Alias Result // method-less view of Result; avoids infinite recursion
	wrapper := struct {
		*Alias
		ErrMsg string `json:"error,omitempty"`
	}{Alias: (*Alias)(r)}
	if err := json.Unmarshal(data, &wrapper); err != nil {
		return err
	}
	r.Error = nil
	if wrapper.ErrMsg != "" {
		r.Error = errors.New(wrapper.ErrMsg)
	}
	return nil
}

// Unmarshal decodes the result's Payload into data (normally a pointer).
// It fails with "payload is nil" when there is no payload.
func (r Result) Unmarshal(data any) error {
	if r.Payload != nil {
		return json.Unmarshal(r.Payload, data)
	}
	return fmt.Errorf("payload is nil")
}

// HandleError wraps err in a Result bound to ctx. A nil err yields a Result
// carrying only ctx. Otherwise the status is Failed unless at least one
// status is supplied, in which case the first one is used.
func HandleError(ctx context.Context, err error, status ...Status) Result {
	if err == nil {
		return Result{Ctx: ctx}
	}
	res := Result{Ctx: ctx, Error: err, Status: Failed}
	if len(status) != 0 {
		res.Status = status[0]
	}
	return res
}

// WithData returns a fresh Result with the given status and payload that
// keeps only r's context. If r already carries an error, r is returned as is.
func (r Result) WithData(status Status, data []byte) Result {
	if r.Error == nil {
		return Result{Ctx: r.Ctx, Status: status, Payload: data}
	}
	return r
}
|
||
|
||
// TLSConfig holds the TLS settings for broker and client connections.
type TLSConfig struct {
	CertPath string // PEM certificate file (loaded with tls.LoadX509KeyPair by Broker.Start)
	KeyPath  string // PEM private key file matching CertPath
	CAPath   string // CA bundle path; NOTE(review): not read by Broker.Start in this file
	UseTLS   bool   // when true the broker listens with TLS instead of plain TCP
}
|
||
|
||
// RateLimiter is a token-bucket style limiter. A background ticker refills C
// at a fixed rate and Wait consumes one token.
type RateLimiter struct {
	C chan struct{}
}

// NewRateLimiter returns a limiter that releases rate tokens per second and
// buffers up to burst unused tokens. The refill goroutine uses a blocking
// send, so tokens are not discarded while the buffer is full.
//
// Edge cases:
//   - rate <= 0 yields an unlimited limiter whose Wait never blocks.
//   - Rates faster than one token per nanosecond are clamped to a 1ns
//     interval, because time.NewTicker panics on a non-positive interval.
//   - A negative burst is treated as 0, i.e. an unbuffered channel.
//
// NOTE(review): the ticker and refill goroutine are never stopped, so each
// limiter lives for the rest of the process.
func NewRateLimiter(rate int, burst int) *RateLimiter {
	if burst < 0 {
		burst = 0
	}
	rl := &RateLimiter{C: make(chan struct{}, burst)}
	if rate <= 0 {
		// Receiving from a closed channel never blocks, so Wait returns at once.
		close(rl.C)
		return rl
	}
	interval := time.Second / time.Duration(rate)
	if interval <= 0 {
		interval = time.Nanosecond
	}
	ticker := time.NewTicker(interval)
	go func() {
		for range ticker.C {
			rl.C <- struct{}{} // blocking send: keep the token until someone waits
		}
	}()
	return rl
}

// Wait blocks until a token is available and consumes it.
func (rl *RateLimiter) Wait() {
	<-rl.C
}
|
||
|
||
type Options struct {
|
||
storage TaskStorage
|
||
consumerOnSubscribe func(ctx context.Context, topic, consumerName string)
|
||
consumerOnClose func(ctx context.Context, topic, consumerName string)
|
||
notifyResponse func(context.Context, Result) error
|
||
brokerAddr string
|
||
tlsConfig TLSConfig
|
||
callback []func(context.Context, Result) Result
|
||
queueSize int
|
||
initialDelay time.Duration
|
||
maxBackoff time.Duration
|
||
jitterPercent float64
|
||
maxRetries int
|
||
numOfWorkers int
|
||
maxMemoryLoad int64
|
||
syncMode bool
|
||
cleanTaskOnComplete bool
|
||
enableWorkerPool bool
|
||
respondPendingResult bool
|
||
logger logger.Logger
|
||
BrokerRateLimiter *RateLimiter // new field for broker rate limiting
|
||
ConsumerRateLimiter *RateLimiter // new field for consumer rate limiting
|
||
}
|
||
|
||
func (o *Options) SetSyncMode(sync bool) {
|
||
o.syncMode = sync
|
||
}
|
||
|
||
func (o *Options) NumOfWorkers() int {
|
||
return o.numOfWorkers
|
||
}
|
||
|
||
func (o *Options) Logger() logger.Logger {
|
||
return o.logger
|
||
}
|
||
|
||
func (o *Options) Storage() TaskStorage {
|
||
return o.storage
|
||
}
|
||
|
||
func (o *Options) CleanTaskOnComplete() bool {
|
||
return o.cleanTaskOnComplete
|
||
}
|
||
|
||
func (o *Options) QueueSize() int {
|
||
return o.queueSize
|
||
}
|
||
|
||
func (o *Options) MaxMemoryLoad() int64 {
|
||
return o.maxMemoryLoad
|
||
}
|
||
|
||
func (o *Options) BrokerAddr() string {
|
||
return o.brokerAddr
|
||
}
|
||
|
||
func HeadersWithConsumerID(ctx context.Context, id string) map[string]string {
|
||
return WithHeaders(ctx, map[string]string{consts.ConsumerKey: id, consts.ContentType: consts.TypeJson})
|
||
}
|
||
|
||
func HeadersWithConsumerIDAndQueue(ctx context.Context, id, queue string) map[string]string {
|
||
return WithHeaders(ctx, map[string]string{
|
||
consts.ConsumerKey: id,
|
||
consts.ContentType: consts.TypeJson,
|
||
consts.QueueKey: queue,
|
||
})
|
||
}
|
||
|
||
type QueuedTask struct {
|
||
Message *codec.Message
|
||
RetryCount int
|
||
}
|
||
|
||
type consumer struct {
|
||
conn net.Conn
|
||
id string
|
||
state consts.ConsumerState
|
||
}
|
||
|
||
type publisher struct {
|
||
conn net.Conn
|
||
id string
|
||
}
|
||
|
||
type Broker struct {
|
||
queues storage.IMap[string, *Queue]
|
||
consumers storage.IMap[string, *consumer]
|
||
publishers storage.IMap[string, *publisher]
|
||
deadLetter storage.IMap[string, *Queue]
|
||
opts *Options
|
||
listener net.Listener
|
||
}
|
||
|
||
func NewBroker(opts ...Option) *Broker {
|
||
options := SetupOptions(opts...)
|
||
return &Broker{
|
||
queues: memory.New[string, *Queue](),
|
||
publishers: memory.New[string, *publisher](),
|
||
consumers: memory.New[string, *consumer](),
|
||
deadLetter: memory.New[string, *Queue](),
|
||
opts: options,
|
||
}
|
||
}
|
||
|
||
func (b *Broker) Options() *Options {
|
||
return b.opts
|
||
}
|
||
|
||
func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
|
||
consumerID, ok := GetConsumerID(ctx)
|
||
if ok && consumerID != "" {
|
||
log.Printf("Broker: Consumer connection closed: %s, address: %s", consumerID, conn.RemoteAddr())
|
||
if con, exists := b.consumers.Get(consumerID); exists {
|
||
con.conn.Close()
|
||
b.consumers.Del(consumerID)
|
||
}
|
||
b.queues.ForEach(func(_ string, queue *Queue) bool {
|
||
if _, ok := queue.consumers.Get(consumerID); ok {
|
||
if b.opts.consumerOnClose != nil {
|
||
b.opts.consumerOnClose(ctx, queue.name, consumerID)
|
||
}
|
||
queue.consumers.Del(consumerID)
|
||
}
|
||
return true
|
||
})
|
||
} else {
|
||
b.consumers.ForEach(func(consumerID string, con *consumer) bool {
|
||
if utils.ConnectionsEqual(conn, con.conn) {
|
||
log.Printf("Broker: Consumer connection closed: %s, address: %s", consumerID, conn.RemoteAddr())
|
||
con.conn.Close()
|
||
b.consumers.Del(consumerID)
|
||
b.queues.ForEach(func(_ string, queue *Queue) bool {
|
||
queue.consumers.Del(consumerID)
|
||
if _, ok := queue.consumers.Get(consumerID); ok {
|
||
if b.opts.consumerOnClose != nil {
|
||
b.opts.consumerOnClose(ctx, queue.name, consumerID)
|
||
}
|
||
}
|
||
return true
|
||
})
|
||
}
|
||
return true
|
||
})
|
||
}
|
||
|
||
publisherID, ok := GetPublisherID(ctx)
|
||
if ok && publisherID != "" {
|
||
log.Printf("Broker: Publisher connection closed: %s, address: %s", publisherID, conn.RemoteAddr())
|
||
if con, exists := b.publishers.Get(publisherID); exists {
|
||
con.conn.Close()
|
||
b.publishers.Del(publisherID)
|
||
}
|
||
}
|
||
log.Printf("BROKER - Connection closed: address %s", conn.RemoteAddr())
|
||
return nil
|
||
}
|
||
|
||
func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) {
|
||
if conn != nil {
|
||
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
|
||
}
|
||
}
|
||
|
||
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||
switch msg.Command {
|
||
case consts.PUBLISH:
|
||
b.PublishHandler(ctx, conn, msg)
|
||
case consts.SUBSCRIBE:
|
||
b.SubscribeHandler(ctx, conn, msg)
|
||
case consts.MESSAGE_RESPONSE:
|
||
b.MessageResponseHandler(ctx, msg)
|
||
case consts.MESSAGE_ACK:
|
||
b.MessageAck(ctx, msg)
|
||
case consts.MESSAGE_DENY:
|
||
b.MessageDeny(ctx, msg)
|
||
case consts.CONSUMER_PAUSED:
|
||
b.OnConsumerPause(ctx, msg)
|
||
case consts.CONSUMER_RESUMED:
|
||
b.OnConsumerResume(ctx, msg)
|
||
case consts.CONSUMER_STOPPED:
|
||
b.OnConsumerStop(ctx, msg)
|
||
default:
|
||
log.Printf("BROKER - UNKNOWN_COMMAND ~> %s on %s", msg.Command, msg.Queue)
|
||
}
|
||
}
|
||
|
||
func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) {
|
||
consumerID, _ := GetConsumerID(ctx)
|
||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||
log.Printf("BROKER - MESSAGE_ACK ~> %s on %s for Task %s", consumerID, msg.Queue, taskID)
|
||
}
|
||
|
||
func (b *Broker) MessageDeny(ctx context.Context, msg *codec.Message) {
|
||
consumerID, _ := GetConsumerID(ctx)
|
||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||
taskError, _ := jsonparser.GetString(msg.Payload, "error")
|
||
log.Printf("BROKER - MESSAGE_DENY ~> %s on %s for Task %s, Error: %s", consumerID, msg.Queue, taskID, taskError)
|
||
}
|
||
|
||
func (b *Broker) OnConsumerPause(ctx context.Context, _ *codec.Message) {
|
||
consumerID, _ := GetConsumerID(ctx)
|
||
if consumerID != "" {
|
||
if con, exists := b.consumers.Get(consumerID); exists {
|
||
con.state = consts.ConsumerStatePaused
|
||
log.Printf("BROKER - CONSUMER ~> Paused %s", consumerID)
|
||
}
|
||
}
|
||
}
|
||
|
||
func (b *Broker) OnConsumerStop(ctx context.Context, _ *codec.Message) {
|
||
consumerID, _ := GetConsumerID(ctx)
|
||
if consumerID != "" {
|
||
if con, exists := b.consumers.Get(consumerID); exists {
|
||
con.state = consts.ConsumerStateStopped
|
||
log.Printf("BROKER - CONSUMER ~> Stopped %s", consumerID)
|
||
if b.opts.notifyResponse != nil {
|
||
result := Result{
|
||
Status: "STOPPED",
|
||
Topic: "", // adjust if queue name is available
|
||
TaskID: consumerID,
|
||
Ctx: ctx,
|
||
}
|
||
_ = b.opts.notifyResponse(ctx, result)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
func (b *Broker) OnConsumerResume(ctx context.Context, _ *codec.Message) {
|
||
consumerID, _ := GetConsumerID(ctx)
|
||
if consumerID != "" {
|
||
if con, exists := b.consumers.Get(consumerID); exists {
|
||
con.state = consts.ConsumerStateActive
|
||
log.Printf("BROKER - CONSUMER ~> Resumed %s", consumerID)
|
||
}
|
||
}
|
||
}
|
||
|
||
func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) {
|
||
msg.Command = consts.RESPONSE
|
||
b.HandleCallback(ctx, msg)
|
||
awaitResponse, ok := GetAwaitResponse(ctx)
|
||
if !(ok && awaitResponse == "true") {
|
||
return
|
||
}
|
||
publisherID, exists := GetPublisherID(ctx)
|
||
if !exists {
|
||
return
|
||
}
|
||
con, ok := b.publishers.Get(publisherID)
|
||
if !ok {
|
||
return
|
||
}
|
||
err := b.send(ctx, con.conn, msg)
|
||
if err != nil {
|
||
panic(err)
|
||
}
|
||
}
|
||
|
||
func (b *Broker) Publish(ctx context.Context, task *Task, queue string) error {
|
||
headers, _ := GetHeaders(ctx)
|
||
payload, err := json.Marshal(task)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers.AsMap())
|
||
b.broadcastToConsumers(msg)
|
||
return nil
|
||
}
|
||
|
||
func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
||
pub := b.addPublisher(ctx, msg.Queue, conn)
|
||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
|
||
|
||
ack := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
|
||
if err := b.send(ctx, conn, ack); err != nil {
|
||
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
|
||
}
|
||
b.broadcastToConsumers(msg)
|
||
go func() {
|
||
select {
|
||
case <-ctx.Done():
|
||
b.publishers.Del(pub.id)
|
||
}
|
||
}()
|
||
}
|
||
|
||
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
||
consumerID := b.AddConsumer(ctx, msg.Queue, conn)
|
||
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
|
||
if err := b.send(ctx, conn, ack); err != nil {
|
||
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
|
||
}
|
||
if b.opts.consumerOnSubscribe != nil {
|
||
b.opts.consumerOnSubscribe(ctx, msg.Queue, consumerID)
|
||
}
|
||
go func() {
|
||
select {
|
||
case <-ctx.Done():
|
||
b.RemoveConsumer(consumerID, msg.Queue)
|
||
}
|
||
}()
|
||
}
|
||
|
||
func (b *Broker) Start(ctx context.Context) error {
|
||
var listener net.Listener
|
||
var err error
|
||
if b.opts.tlsConfig.UseTLS {
|
||
cert, err := tls.LoadX509KeyPair(b.opts.tlsConfig.CertPath, b.opts.tlsConfig.KeyPath)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to load TLS certificates: %v", err)
|
||
}
|
||
tlsConfig := &tls.Config{
|
||
Certificates: []tls.Certificate{cert},
|
||
}
|
||
listener, err = tls.Listen("tcp", b.opts.brokerAddr, tlsConfig)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to start TLS listener: %v", err)
|
||
}
|
||
log.Println("BROKER - RUNNING_TLS ~> started on", b.opts.brokerAddr)
|
||
} else {
|
||
listener, err = net.Listen("tcp", b.opts.brokerAddr)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to start TCP listener: %v", err)
|
||
}
|
||
log.Println("BROKER - RUNNING ~> started on", b.opts.brokerAddr)
|
||
}
|
||
b.listener = listener
|
||
defer b.Close()
|
||
const maxConcurrentConnections = 100
|
||
sem := make(chan struct{}, maxConcurrentConnections)
|
||
for {
|
||
conn, err := listener.Accept()
|
||
if err != nil {
|
||
b.OnError(ctx, conn, err)
|
||
time.Sleep(50 * time.Millisecond)
|
||
continue
|
||
}
|
||
sem <- struct{}{}
|
||
go func(c net.Conn) {
|
||
defer func() {
|
||
<-sem
|
||
c.Close()
|
||
}()
|
||
for {
|
||
err := b.readMessage(ctx, c)
|
||
if err != nil {
|
||
if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
|
||
log.Println("Temporary network error, retrying:", netErr)
|
||
continue
|
||
}
|
||
log.Println("Connection closed due to error:", err)
|
||
break
|
||
}
|
||
}
|
||
}(conn)
|
||
}
|
||
}
|
||
|
||
func (b *Broker) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
|
||
return codec.SendMessage(ctx, conn, msg)
|
||
}
|
||
|
||
func (b *Broker) receive(ctx context.Context, c net.Conn) (*codec.Message, error) {
|
||
return codec.ReadMessage(ctx, c)
|
||
}
|
||
|
||
func (b *Broker) broadcastToConsumers(msg *codec.Message) {
|
||
if queue, ok := b.queues.Get(msg.Queue); ok {
|
||
task := &QueuedTask{Message: msg, RetryCount: 0}
|
||
queue.tasks <- task
|
||
}
|
||
}
|
||
|
||
func (b *Broker) waitForConsumerAck(ctx context.Context, conn net.Conn) error {
|
||
msg, err := b.receive(ctx, conn)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if msg.Command == consts.MESSAGE_ACK {
|
||
log.Println("Received CONSUMER_ACK: Subscribed successfully")
|
||
return nil
|
||
}
|
||
return fmt.Errorf("expected CONSUMER_ACK, got: %v", msg.Command)
|
||
}
|
||
|
||
func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) *publisher {
|
||
publisherID, ok := GetPublisherID(ctx)
|
||
_, ok = b.queues.Get(queueName)
|
||
if !ok {
|
||
b.NewQueue(queueName)
|
||
}
|
||
con := &publisher{id: publisherID, conn: conn}
|
||
b.publishers.Set(publisherID, con)
|
||
return con
|
||
}
|
||
|
||
func (b *Broker) AddConsumer(ctx context.Context, queueName string, conn net.Conn) string {
|
||
consumerID, ok := GetConsumerID(ctx)
|
||
q, ok := b.queues.Get(queueName)
|
||
if !ok {
|
||
q = b.NewQueue(queueName)
|
||
}
|
||
con := &consumer{id: consumerID, conn: conn}
|
||
b.consumers.Set(consumerID, con)
|
||
q.consumers.Set(consumerID, con)
|
||
log.Printf("BROKER - SUBSCRIBE ~> %s on %s", consumerID, queueName)
|
||
return consumerID
|
||
}
|
||
|
||
func (b *Broker) RemoveConsumer(consumerID string, queues ...string) {
|
||
if len(queues) > 0 {
|
||
for _, queueName := range queues {
|
||
if queue, ok := b.queues.Get(queueName); ok {
|
||
con, ok := queue.consumers.Get(consumerID)
|
||
if ok {
|
||
con.conn.Close()
|
||
queue.consumers.Del(consumerID)
|
||
}
|
||
b.queues.Del(queueName)
|
||
}
|
||
}
|
||
return
|
||
}
|
||
b.queues.ForEach(func(queueName string, queue *Queue) bool {
|
||
con, ok := queue.consumers.Get(consumerID)
|
||
if ok {
|
||
con.conn.Close()
|
||
queue.consumers.Del(consumerID)
|
||
}
|
||
b.queues.Del(queueName)
|
||
return true
|
||
})
|
||
}
|
||
|
||
func (b *Broker) handleConsumer(ctx context.Context, cmd consts.CMD, state consts.ConsumerState, consumerID string, queues ...string) {
|
||
fn := func(queue *Queue) {
|
||
con, ok := queue.consumers.Get(consumerID)
|
||
if ok {
|
||
ack := codec.NewMessage(cmd, utils.ToByte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID})
|
||
err := b.send(ctx, con.conn, ack)
|
||
if err == nil {
|
||
con.state = state
|
||
}
|
||
}
|
||
}
|
||
if len(queues) > 0 {
|
||
for _, queueName := range queues {
|
||
if queue, ok := b.queues.Get(queueName); ok {
|
||
fn(queue)
|
||
}
|
||
}
|
||
return
|
||
}
|
||
b.queues.ForEach(func(queueName string, queue *Queue) bool {
|
||
fn(queue)
|
||
return true
|
||
})
|
||
}
|
||
|
||
func (b *Broker) PauseConsumer(ctx context.Context, consumerID string, queues ...string) {
|
||
b.handleConsumer(ctx, consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, queues...)
|
||
}
|
||
|
||
func (b *Broker) ResumeConsumer(ctx context.Context, consumerID string, queues ...string) {
|
||
b.handleConsumer(ctx, consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, queues...)
|
||
}
|
||
|
||
func (b *Broker) StopConsumer(ctx context.Context, consumerID string, queues ...string) {
|
||
b.handleConsumer(ctx, consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, queues...)
|
||
}
|
||
|
||
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
|
||
msg, err := b.receive(ctx, c)
|
||
if err == nil {
|
||
ctx = SetHeaders(ctx, msg.Headers)
|
||
b.OnMessage(ctx, msg, c)
|
||
return nil
|
||
}
|
||
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
|
||
b.OnClose(ctx, c)
|
||
return err
|
||
}
|
||
b.OnError(ctx, c, err)
|
||
return err
|
||
}
|
||
|
||
func (b *Broker) dispatchWorker(ctx context.Context, queue *Queue) {
|
||
delay := b.opts.initialDelay
|
||
for task := range queue.tasks {
|
||
if b.opts.BrokerRateLimiter != nil {
|
||
b.opts.BrokerRateLimiter.Wait()
|
||
}
|
||
success := false
|
||
for !success && task.RetryCount <= b.opts.maxRetries {
|
||
if b.dispatchTaskToConsumer(ctx, queue, task) {
|
||
success = true
|
||
} else {
|
||
task.RetryCount++
|
||
delay = b.backoffRetry(queue, task, delay)
|
||
}
|
||
}
|
||
if task.RetryCount > b.opts.maxRetries {
|
||
b.sendToDLQ(queue, task)
|
||
}
|
||
}
|
||
}
|
||
|
||
func (b *Broker) sendToDLQ(queue *Queue, task *QueuedTask) {
|
||
id, _ := jsonparser.GetString(task.Message.Payload, "id")
|
||
if dlq, ok := b.deadLetter.Get(queue.name); ok {
|
||
log.Printf("Sending task %s to dead-letter queue for %s", id, queue.name)
|
||
dlq.tasks <- task
|
||
} else {
|
||
log.Printf("No dead-letter queue for %s, discarding task %s", queue.name, id)
|
||
}
|
||
}
|
||
|
||
func (b *Broker) dispatchTaskToConsumer(ctx context.Context, queue *Queue, task *QueuedTask) bool {
|
||
var consumerFound bool
|
||
var err error
|
||
queue.consumers.ForEach(func(_ string, con *consumer) bool {
|
||
if con.state != consts.ConsumerStateActive {
|
||
err = fmt.Errorf("consumer %s is not active", con.id)
|
||
return true
|
||
}
|
||
if err := b.send(ctx, con.conn, task.Message); err == nil {
|
||
consumerFound = true
|
||
return false
|
||
}
|
||
return true
|
||
})
|
||
if err != nil {
|
||
log.Println(err.Error())
|
||
return false
|
||
}
|
||
if !consumerFound {
|
||
log.Printf("No available consumers for queue %s, retrying...", queue.name)
|
||
if b.opts.notifyResponse != nil {
|
||
result := Result{
|
||
Status: "NO_CONSUMER",
|
||
Topic: queue.name,
|
||
TaskID: "",
|
||
Ctx: ctx,
|
||
}
|
||
_ = b.opts.notifyResponse(ctx, result)
|
||
}
|
||
}
|
||
return consumerFound
|
||
}
|
||
|
||
// Modified backoffRetry: Removed re‑insertion of the task into queue.tasks.
|
||
func (b *Broker) backoffRetry(queue *Queue, task *QueuedTask, delay time.Duration) time.Duration {
|
||
backoffDuration := utils.CalculateJitter(delay, b.opts.jitterPercent)
|
||
log.Printf("Backing off for %v before retrying task for queue %s", backoffDuration, task.Message.Queue)
|
||
time.Sleep(backoffDuration)
|
||
delay *= 2
|
||
if delay > b.opts.maxBackoff {
|
||
delay = b.opts.maxBackoff
|
||
}
|
||
return delay
|
||
}
|
||
|
||
func (b *Broker) URL() string {
|
||
return b.opts.brokerAddr
|
||
}
|
||
|
||
func (b *Broker) Close() error {
|
||
if b != nil && b.listener != nil {
|
||
log.Printf("Broker is closing...")
|
||
return b.listener.Close()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (b *Broker) SetURL(url string) {
|
||
b.opts.brokerAddr = url
|
||
}
|
||
|
||
type Processor interface {
|
||
ProcessTask(ctx context.Context, msg *Task) Result
|
||
Consume(ctx context.Context) error
|
||
Pause(ctx context.Context) error
|
||
Resume(ctx context.Context) error
|
||
Stop(ctx context.Context) error
|
||
Close() error
|
||
GetKey() string
|
||
SetKey(key string)
|
||
GetType() string
|
||
}
|
||
|
||
type Consumer struct {
|
||
conn net.Conn
|
||
handler Handler
|
||
pool *Pool
|
||
opts *Options
|
||
id string
|
||
queue string
|
||
}
|
||
|
||
func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Consumer {
|
||
options := SetupOptions(opts...)
|
||
return &Consumer{
|
||
id: id,
|
||
opts: options,
|
||
queue: queue,
|
||
handler: handler,
|
||
}
|
||
}
|
||
|
||
func (c *Consumer) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
|
||
return codec.SendMessage(ctx, conn, msg)
|
||
}
|
||
|
||
func (c *Consumer) receive(ctx context.Context, conn net.Conn) (*codec.Message, error) {
|
||
return codec.ReadMessage(ctx, conn)
|
||
}
|
||
|
||
func (c *Consumer) Close() error {
|
||
c.pool.Stop()
|
||
err := c.conn.Close()
|
||
log.Printf("CONSUMER - Connection closed for consumer: %s", c.id)
|
||
return err
|
||
}
|
||
|
||
func (c *Consumer) GetKey() string {
|
||
return c.id
|
||
}
|
||
|
||
func (c *Consumer) GetType() string {
|
||
return "consumer"
|
||
}
|
||
|
||
func (c *Consumer) SetKey(key string) {
|
||
c.id = key
|
||
}
|
||
|
||
func (c *Consumer) Metrics() Metrics {
|
||
return c.pool.Metrics()
|
||
}
|
||
|
||
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
|
||
headers := HeadersWithConsumerID(ctx, c.id)
|
||
msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
|
||
if err := c.send(ctx, c.conn, msg); err != nil {
|
||
return fmt.Errorf("error while trying to subscribe: %v", err)
|
||
}
|
||
return c.waitForAck(ctx, c.conn)
|
||
}
|
||
|
||
func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error {
|
||
fmt.Println("Consumer closed")
|
||
return nil
|
||
}
|
||
|
||
func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) {
|
||
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
|
||
}
|
||
|
||
func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) error {
|
||
switch msg.Command {
|
||
case consts.PUBLISH:
|
||
c.ConsumeMessage(ctx, msg, conn)
|
||
case consts.CONSUMER_PAUSE:
|
||
err := c.Pause(ctx)
|
||
if err != nil {
|
||
log.Printf("Unable to pause consumer: %v", err)
|
||
}
|
||
return err
|
||
case consts.CONSUMER_RESUME:
|
||
err := c.Resume(ctx)
|
||
if err != nil {
|
||
log.Printf("Unable to resume consumer: %v", err)
|
||
}
|
||
return err
|
||
case consts.CONSUMER_STOP:
|
||
err := c.Stop(ctx)
|
||
if err != nil {
|
||
log.Printf("Unable to stop consumer: %v", err)
|
||
}
|
||
return err
|
||
default:
|
||
log.Printf("CONSUMER - UNKNOWN_COMMAND ~> %s on %s", msg.Command, msg.Queue)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (c *Consumer) sendMessageAck(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue)
|
||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||
reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
|
||
if err := c.send(ctx, conn, reply); err != nil {
|
||
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
|
||
}
|
||
}
|
||
|
||
func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||
c.sendMessageAck(ctx, msg, conn)
|
||
if msg.Payload == nil {
|
||
log.Printf("Received empty message payload")
|
||
return
|
||
}
|
||
var task Task
|
||
err := json.Unmarshal(msg.Payload, &task)
|
||
if err != nil {
|
||
log.Printf("Error unmarshalling message: %v", err)
|
||
return
|
||
}
|
||
ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue})
|
||
if err := c.pool.EnqueueTask(ctx, &task, 1); err != nil {
|
||
c.sendDenyMessage(ctx, task.ID, msg.Queue, err)
|
||
return
|
||
}
|
||
}
|
||
|
||
func (c *Consumer) ProcessTask(ctx context.Context, msg *Task) Result {
|
||
defer RecoverPanic(RecoverTitle)
|
||
queue, _ := GetQueue(ctx)
|
||
if msg.Topic == "" && queue != "" {
|
||
msg.Topic = queue
|
||
}
|
||
result := c.handler(ctx, msg)
|
||
result.Topic = msg.Topic
|
||
result.TaskID = msg.ID
|
||
return result
|
||
}
|
||
|
||
func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
|
||
if result.Status == "PENDING" && c.opts.respondPendingResult {
|
||
return nil
|
||
}
|
||
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, result.Topic)
|
||
if result.Status == "" {
|
||
if result.Error != nil {
|
||
result.Status = "FAILED"
|
||
} else {
|
||
result.Status = "SUCCESS"
|
||
}
|
||
}
|
||
bt, _ := json.Marshal(result)
|
||
reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers)
|
||
if err := c.send(ctx, c.conn, reply); err != nil {
|
||
return fmt.Errorf("failed to send MESSAGE_RESPONSE: %v", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) {
|
||
headers := HeadersWithConsumerID(ctx, c.id)
|
||
reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
|
||
if sendErr := c.send(ctx, c.conn, reply); sendErr != nil {
|
||
log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr)
|
||
}
|
||
}
|
||
|
||
func (c *Consumer) attemptConnect() error {
|
||
var err error
|
||
delay := c.opts.initialDelay
|
||
for i := 0; i < c.opts.maxRetries; i++ {
|
||
conn, err := GetConnection(c.opts.brokerAddr, c.opts.tlsConfig)
|
||
if err == nil {
|
||
c.conn = conn
|
||
return nil
|
||
}
|
||
sleepDuration := utils.CalculateJitter(delay, c.opts.jitterPercent)
|
||
log.Printf("CONSUMER - SUBSCRIBE ~> Failed connecting to %s (attempt %d/%d): %v, Retrying in %v...\n", c.opts.brokerAddr, i+1, c.opts.maxRetries, err, sleepDuration)
|
||
time.Sleep(sleepDuration)
|
||
delay *= 2
|
||
if delay > c.opts.maxBackoff {
|
||
delay = c.opts.maxBackoff
|
||
}
|
||
}
|
||
|
||
return fmt.Errorf("could not connect to server %s after %d attempts: %w", c.opts.brokerAddr, c.opts.maxRetries, err)
|
||
}
|
||
|
||
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
|
||
msg, err := c.receive(ctx, conn)
|
||
if err == nil {
|
||
ctx = SetHeaders(ctx, msg.Headers)
|
||
return c.OnMessage(ctx, msg, conn)
|
||
}
|
||
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
|
||
err1 := c.OnClose(ctx, conn)
|
||
if err1 != nil {
|
||
return err1
|
||
}
|
||
return err
|
||
}
|
||
c.OnError(ctx, conn, err)
|
||
return err
|
||
}
|
||
|
||
func (c *Consumer) Consume(ctx context.Context) error {
|
||
err := c.attemptConnect()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
c.pool = NewPool(
|
||
c.opts.numOfWorkers,
|
||
WithTaskQueueSize(c.opts.queueSize),
|
||
WithMaxMemoryLoad(c.opts.maxMemoryLoad),
|
||
WithHandler(c.ProcessTask),
|
||
WithPoolCallback(c.OnResponse),
|
||
WithTaskStorage(c.opts.storage),
|
||
)
|
||
if err := c.subscribe(ctx, c.queue); err != nil {
|
||
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
|
||
}
|
||
c.pool.Start(c.opts.numOfWorkers)
|
||
// Infinite loop to continuously read messages and reconnect if needed.
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
log.Println("Context canceled, stopping consumer.")
|
||
return nil
|
||
default:
|
||
if c.opts.ConsumerRateLimiter != nil {
|
||
c.opts.ConsumerRateLimiter.Wait()
|
||
}
|
||
if err := c.readMessage(ctx, c.conn); err != nil {
|
||
log.Printf("Error reading message: %v, attempting reconnection...", err)
|
||
// Attempt reconnection loop.
|
||
for {
|
||
if ctx.Err() != nil {
|
||
return nil
|
||
}
|
||
if rErr := c.attemptConnect(); rErr != nil {
|
||
log.Printf("Reconnection attempt failed: %v", rErr)
|
||
time.Sleep(c.opts.initialDelay)
|
||
} else {
|
||
break
|
||
}
|
||
}
|
||
if err := c.subscribe(ctx, c.queue); err != nil {
|
||
log.Printf("Failed to re-subscribe on reconnection: %v", err)
|
||
time.Sleep(c.opts.initialDelay)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
func (c *Consumer) waitForAck(ctx context.Context, conn net.Conn) error {
|
||
msg, err := c.receive(ctx, conn)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if msg.Command == consts.SUBSCRIBE_ACK {
|
||
log.Printf("CONSUMER - SUBSCRIBE_ACK ~> %s on %s", c.id, msg.Queue)
|
||
return nil
|
||
}
|
||
return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command)
|
||
}
|
||
|
||
func (c *Consumer) Pause(ctx context.Context) error {
|
||
return c.operate(ctx, consts.CONSUMER_PAUSED, c.pool.Pause)
|
||
}
|
||
|
||
func (c *Consumer) Resume(ctx context.Context) error {
|
||
return c.operate(ctx, consts.CONSUMER_RESUMED, c.pool.Resume)
|
||
}
|
||
|
||
func (c *Consumer) Stop(ctx context.Context) error {
|
||
return c.operate(ctx, consts.CONSUMER_STOPPED, c.pool.Stop)
|
||
}
|
||
|
||
func (c *Consumer) operate(ctx context.Context, cmd consts.CMD, poolOperation func()) error {
|
||
if err := c.sendOpsMessage(ctx, cmd); err != nil {
|
||
return err
|
||
}
|
||
poolOperation()
|
||
return nil
|
||
}
|
||
|
||
func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
|
||
headers := HeadersWithConsumerID(ctx, c.id)
|
||
msg := codec.NewMessage(cmd, nil, c.queue, headers)
|
||
return c.send(ctx, c.conn, msg)
|
||
}
|
||
|
||
func (c *Consumer) Conn() net.Conn {
|
||
return c.conn
|
||
}
|
||
|
||
type Publisher struct {
|
||
opts *Options
|
||
id string
|
||
conn net.Conn
|
||
connLock sync.Mutex
|
||
}
|
||
|
||
func NewPublisher(id string, opts ...Option) *Publisher {
|
||
options := SetupOptions(opts...)
|
||
return &Publisher{
|
||
id: id,
|
||
opts: options,
|
||
conn: nil,
|
||
}
|
||
}
|
||
|
||
// New method to ensure a persistent connection.
|
||
func (p *Publisher) ensureConnection(ctx context.Context) error {
|
||
p.connLock.Lock()
|
||
defer p.connLock.Unlock()
|
||
if p.conn != nil {
|
||
return nil
|
||
}
|
||
var err error
|
||
delay := p.opts.initialDelay
|
||
for i := 0; i < p.opts.maxRetries; i++ {
|
||
var conn net.Conn
|
||
conn, err = GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
||
if err == nil {
|
||
p.conn = conn
|
||
return nil
|
||
}
|
||
sleepDuration := utils.CalculateJitter(delay, p.opts.jitterPercent)
|
||
log.Printf("PUBLISHER - ensureConnection failed: %v, attempt %d/%d, retrying in %v...", err, i+1, p.opts.maxRetries, sleepDuration)
|
||
time.Sleep(sleepDuration)
|
||
delay *= 2
|
||
if delay > p.opts.maxBackoff {
|
||
delay = p.opts.maxBackoff
|
||
}
|
||
}
|
||
return fmt.Errorf("failed to connect to broker after retries: %w", err)
|
||
}
|
||
|
||
// Modified Publish method that uses the persistent connection.
|
||
func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error {
|
||
// Ensure connection is established.
|
||
if err := p.ensureConnection(ctx); err != nil {
|
||
return err
|
||
}
|
||
delay := p.opts.initialDelay
|
||
for i := 0; i < p.opts.maxRetries; i++ {
|
||
// Use the persistent connection.
|
||
p.connLock.Lock()
|
||
conn := p.conn
|
||
p.connLock.Unlock()
|
||
err := p.send(ctx, queue, task, conn, consts.PUBLISH)
|
||
if err == nil {
|
||
return nil
|
||
}
|
||
log.Printf("PUBLISHER - Failed publishing: %v, attempt %d/%d, retrying...", err, i+1, p.opts.maxRetries)
|
||
// On error, close and reset the connection.
|
||
p.connLock.Lock()
|
||
if p.conn != nil {
|
||
p.conn.Close()
|
||
p.conn = nil
|
||
}
|
||
p.connLock.Unlock()
|
||
sleepDuration := utils.CalculateJitter(delay, p.opts.jitterPercent)
|
||
time.Sleep(sleepDuration)
|
||
delay *= 2
|
||
if delay > p.opts.maxBackoff {
|
||
delay = p.opts.maxBackoff
|
||
}
|
||
// Ensure connection is re-established.
|
||
if err := p.ensureConnection(ctx); err != nil {
|
||
return err
|
||
}
|
||
}
|
||
return fmt.Errorf("failed to publish after retries")
|
||
}
|
||
|
||
func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.Conn, command consts.CMD) error {
|
||
headers := WithHeaders(ctx, map[string]string{
|
||
consts.PublisherKey: p.id,
|
||
consts.ContentType: consts.TypeJson,
|
||
})
|
||
if task.ID == "" {
|
||
task.ID = NewID()
|
||
}
|
||
task.CreatedAt = time.Now()
|
||
payload, err := json.Marshal(task)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
msg := codec.NewMessage(command, payload, queue, headers)
|
||
if err := codec.SendMessage(ctx, conn, msg); err != nil {
|
||
return err
|
||
}
|
||
|
||
return p.waitForAck(ctx, conn)
|
||
}
|
||
|
||
func (p *Publisher) waitForAck(ctx context.Context, conn net.Conn) error {
|
||
msg, err := codec.ReadMessage(ctx, conn)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if msg.Command == consts.PUBLISH_ACK {
|
||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||
log.Printf("PUBLISHER - PUBLISH_ACK ~> from %s on %s for Task %s", p.id, msg.Queue, taskID)
|
||
return nil
|
||
}
|
||
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
|
||
}
|
||
|
||
func (p *Publisher) waitForResponse(ctx context.Context, conn net.Conn) Result {
|
||
msg, err := codec.ReadMessage(ctx, conn)
|
||
if err != nil {
|
||
return Result{Error: err}
|
||
}
|
||
if msg.Command == consts.RESPONSE {
|
||
var result Result
|
||
err = json.Unmarshal(msg.Payload, &result)
|
||
return result
|
||
}
|
||
err = fmt.Errorf("expected RESPONSE, got: %v", msg.Command)
|
||
return Result{Error: err}
|
||
}
|
||
|
||
func (p *Publisher) onClose(_ context.Context, conn net.Conn) error {
|
||
fmt.Println("Publisher Connection closed", p.id, conn.RemoteAddr())
|
||
return nil
|
||
}
|
||
|
||
func (p *Publisher) onError(_ context.Context, conn net.Conn, err error) {
|
||
fmt.Println("Error reading from publisher connection:", err, conn.RemoteAddr())
|
||
}
|
||
|
||
func (p *Publisher) Request(ctx context.Context, task Task, queue string) Result {
|
||
ctx = SetHeaders(ctx, map[string]string{
|
||
consts.AwaitResponseKey: "true",
|
||
})
|
||
conn, err := GetConnection(p.opts.brokerAddr, p.opts.tlsConfig)
|
||
if err != nil {
|
||
err = fmt.Errorf("failed to connect to broker: %w", err)
|
||
return Result{Error: err}
|
||
}
|
||
defer func() {
|
||
_ = conn.Close()
|
||
}()
|
||
err = p.send(ctx, queue, task, conn, consts.PUBLISH)
|
||
resultCh := make(chan Result)
|
||
go func() {
|
||
defer close(resultCh)
|
||
resultCh <- p.waitForResponse(ctx, conn)
|
||
}()
|
||
finalResult := <-resultCh
|
||
return finalResult
|
||
}
|
||
|
||
type Queue struct {
|
||
consumers storage.IMap[string, *consumer]
|
||
tasks chan *QueuedTask // channel to hold tasks
|
||
name string
|
||
}
|
||
|
||
func newQueue(name string, queueSize int) *Queue {
|
||
return &Queue{
|
||
name: name,
|
||
consumers: memory.New[string, *consumer](),
|
||
tasks: make(chan *QueuedTask, queueSize), // buffer size for tasks
|
||
}
|
||
}
|
||
|
||
func (b *Broker) NewQueue(name string) *Queue {
|
||
q := &Queue{
|
||
name: name,
|
||
tasks: make(chan *QueuedTask, b.opts.queueSize),
|
||
consumers: memory.New[string, *consumer](),
|
||
}
|
||
b.queues.Set(name, q)
|
||
|
||
// Create DLQ for the queue
|
||
dlq := &Queue{
|
||
name: name + "_dlq",
|
||
tasks: make(chan *QueuedTask, b.opts.queueSize),
|
||
consumers: memory.New[string, *consumer](),
|
||
}
|
||
b.deadLetter.Set(name, dlq)
|
||
ctx := context.Background()
|
||
go b.dispatchWorker(ctx, q)
|
||
go b.dispatchWorker(ctx, dlq)
|
||
return q
|
||
}
|
||
|
||
type QueueTask struct {
|
||
ctx context.Context
|
||
payload *Task
|
||
priority int
|
||
retryCount int
|
||
index int
|
||
}
|
||
|
||
// PriorityQueue is a slice of tasks whose Len/Less/Swap/Push/Pop methods match
// container/heap's heap.Interface.
type PriorityQueue []*QueueTask
|
||
|
||
// Len reports how many tasks are queued.
func (pq PriorityQueue) Len() int {
	return len(pq)
}
|
||
func (pq PriorityQueue) Less(i, j int) bool {
|
||
return pq[i].priority > pq[j].priority
|
||
}
|
||
func (pq PriorityQueue) Swap(i, j int) {
|
||
pq[i], pq[j] = pq[j], pq[i]
|
||
pq[i].index = i
|
||
pq[j].index = j
|
||
}
|
||
func (pq *PriorityQueue) Push(x interface{}) {
|
||
n := len(*pq)
|
||
task := x.(*QueueTask)
|
||
task.index = n
|
||
*pq = append(*pq, task)
|
||
}
|
||
func (pq *PriorityQueue) Pop() interface{} {
|
||
old := *pq
|
||
n := len(old)
|
||
task := old[n-1]
|
||
task.index = -1
|
||
*pq = old[0 : n-1]
|
||
return task
|
||
}
|
||
|
||
type Task struct {
|
||
CreatedAt time.Time `json:"created_at"`
|
||
ProcessedAt time.Time `json:"processed_at"`
|
||
Expiry time.Time `json:"expiry"`
|
||
Error error `json:"error"`
|
||
ID string `json:"id"`
|
||
Topic string `json:"topic"`
|
||
Status string `json:"status"`
|
||
Payload json.RawMessage `json:"payload"`
|
||
dag any
|
||
}
|
||
|
||
func (t *Task) GetFlow() any {
|
||
return t.dag
|
||
}
|
||
|
||
func NewTask(id string, payload json.RawMessage, nodeKey string, opts ...TaskOption) *Task {
|
||
if id == "" {
|
||
id = NewID()
|
||
}
|
||
task := &Task{ID: id, Payload: payload, Topic: nodeKey, CreatedAt: time.Now()}
|
||
for _, opt := range opts {
|
||
opt(task)
|
||
}
|
||
return task
|
||
}
|
||
|
||
// TaskOption defines a function type for setting options.
|
||
type TaskOption func(*Task)
|
||
|
||
func WithDAG(dag any) TaskOption {
|
||
return func(opts *Task) {
|
||
opts.dag = dag
|
||
}
|
||
}
|
||
|
||
func (b *Broker) TLSConfig() TLSConfig {
|
||
return b.opts.tlsConfig
|
||
}
|
||
|
||
func (b *Broker) SyncMode() bool {
|
||
return b.opts.syncMode
|
||
}
|
||
|
||
func (b *Broker) NotifyHandler() func(context.Context, Result) error {
|
||
return b.opts.notifyResponse
|
||
}
|
||
|
||
func (b *Broker) SetNotifyHandler(callback Callback) {
|
||
b.opts.notifyResponse = callback
|
||
}
|
||
|
||
func (b *Broker) HandleCallback(ctx context.Context, msg *codec.Message) {
|
||
if b.opts.callback != nil {
|
||
var result Result
|
||
err := json.Unmarshal(msg.Payload, &result)
|
||
if err == nil {
|
||
for _, callback := range b.opts.callback {
|
||
callback(ctx, result)
|
||
}
|
||
}
|
||
}
|
||
}
|