mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-01 06:12:10 +08:00
feat: add example
This commit is contained in:
122
broker.go
122
broker.go
@@ -25,6 +25,7 @@ type QueuedTask struct {
|
||||
|
||||
type consumer struct {
|
||||
id string
|
||||
state consts.ConsumerState
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
@@ -50,6 +51,10 @@ func NewBroker(opts ...Option) *Broker {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) Options() Options {
|
||||
return b.opts
|
||||
}
|
||||
|
||||
func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
|
||||
consumerID, ok := GetConsumerID(ctx)
|
||||
if ok && consumerID != "" {
|
||||
@@ -110,6 +115,16 @@ func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Con
|
||||
b.MessageResponseHandler(ctx, msg)
|
||||
case consts.MESSAGE_ACK:
|
||||
b.MessageAck(ctx, msg)
|
||||
case consts.MESSAGE_DENY:
|
||||
b.MessageDeny(ctx, msg)
|
||||
case consts.CONSUMER_PAUSED:
|
||||
b.OnConsumerPause(ctx, msg)
|
||||
case consts.CONSUMER_RESUMED:
|
||||
b.OnConsumerResume(ctx, msg)
|
||||
case consts.CONSUMER_STOPPED:
|
||||
b.OnConsumerStop(ctx, msg)
|
||||
default:
|
||||
log.Printf("BROKER - UNKNOWN_COMMAND ~> %s on %s", msg.Command, msg.Queue)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,6 +134,43 @@ func (b *Broker) MessageAck(ctx context.Context, msg *codec.Message) {
|
||||
log.Printf("BROKER - MESSAGE_ACK ~> %s on %s for Task %s", consumerID, msg.Queue, taskID)
|
||||
}
|
||||
|
||||
func (b *Broker) MessageDeny(ctx context.Context, msg *codec.Message) {
|
||||
consumerID, _ := GetConsumerID(ctx)
|
||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||
taskError, _ := jsonparser.GetString(msg.Payload, "error")
|
||||
log.Printf("BROKER - MESSAGE_DENY ~> %s on %s for Task %s, Error: %s", consumerID, msg.Queue, taskID, taskError)
|
||||
}
|
||||
|
||||
func (b *Broker) OnConsumerPause(ctx context.Context, msg *codec.Message) {
|
||||
consumerID, _ := GetConsumerID(ctx)
|
||||
if consumerID != "" {
|
||||
if con, exists := b.consumers.Get(consumerID); exists {
|
||||
con.state = consts.ConsumerStatePaused
|
||||
log.Printf("BROKER - CONSUMER ~> Paused %s", consumerID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) OnConsumerStop(ctx context.Context, msg *codec.Message) {
|
||||
consumerID, _ := GetConsumerID(ctx)
|
||||
if consumerID != "" {
|
||||
if con, exists := b.consumers.Get(consumerID); exists {
|
||||
con.state = consts.ConsumerStateStopped
|
||||
log.Printf("BROKER - CONSUMER ~> Stopped %s", consumerID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) OnConsumerResume(ctx context.Context, msg *codec.Message) {
|
||||
consumerID, _ := GetConsumerID(ctx)
|
||||
if consumerID != "" {
|
||||
if con, exists := b.consumers.Get(consumerID); exists {
|
||||
con.state = consts.ConsumerStateActive
|
||||
log.Printf("BROKER - CONSUMER ~> Resumed %s", consumerID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) {
|
||||
msg.Command = consts.RESPONSE
|
||||
b.HandleCallback(ctx, msg)
|
||||
@@ -170,7 +222,7 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
|
||||
}
|
||||
|
||||
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
||||
consumerID := b.addConsumer(ctx, msg.Queue, conn)
|
||||
consumerID := b.AddConsumer(ctx, msg.Queue, conn)
|
||||
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
|
||||
if err := b.send(conn, ack); err != nil {
|
||||
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
|
||||
@@ -181,7 +233,7 @@ func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
b.removeConsumer(msg.Queue, consumerID)
|
||||
b.RemoveConsumer(consumerID, msg.Queue)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -267,7 +319,7 @@ func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Co
|
||||
return con
|
||||
}
|
||||
|
||||
func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string {
|
||||
func (b *Broker) AddConsumer(ctx context.Context, queueName string, conn net.Conn) string {
|
||||
consumerID, ok := GetConsumerID(ctx)
|
||||
q, ok := b.queues.Get(queueName)
|
||||
if !ok {
|
||||
@@ -280,7 +332,9 @@ func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Con
|
||||
return consumerID
|
||||
}
|
||||
|
||||
func (b *Broker) removeConsumer(queueName, consumerID string) {
|
||||
func (b *Broker) RemoveConsumer(consumerID string, queues ...string) {
|
||||
if len(queues) > 0 {
|
||||
for _, queueName := range queues {
|
||||
if queue, ok := b.queues.Get(queueName); ok {
|
||||
con, ok := queue.consumers.Get(consumerID)
|
||||
if ok {
|
||||
@@ -290,6 +344,55 @@ func (b *Broker) removeConsumer(queueName, consumerID string) {
|
||||
b.queues.Del(queueName)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
b.queues.ForEach(func(queueName string, queue *Queue) bool {
|
||||
con, ok := queue.consumers.Get(consumerID)
|
||||
if ok {
|
||||
con.conn.Close()
|
||||
queue.consumers.Del(consumerID)
|
||||
}
|
||||
b.queues.Del(queueName)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (b *Broker) handleConsumer(cmd consts.CMD, state consts.ConsumerState, consumerID string, queues ...string) {
|
||||
fn := func(queue *Queue) {
|
||||
con, ok := queue.consumers.Get(consumerID)
|
||||
if ok {
|
||||
ack := codec.NewMessage(cmd, []byte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID})
|
||||
err := b.send(con.conn, ack)
|
||||
if err == nil {
|
||||
con.state = state
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(queues) > 0 {
|
||||
for _, queueName := range queues {
|
||||
if queue, ok := b.queues.Get(queueName); ok {
|
||||
fn(queue)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
b.queues.ForEach(func(queueName string, queue *Queue) bool {
|
||||
fn(queue)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (b *Broker) PauseConsumer(consumerID string, queues ...string) {
|
||||
b.handleConsumer(consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, queues...)
|
||||
}
|
||||
|
||||
func (b *Broker) ResumeConsumer(consumerID string, queues ...string) {
|
||||
b.handleConsumer(consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, queues...)
|
||||
}
|
||||
|
||||
func (b *Broker) StopConsumer(consumerID string, queues ...string) {
|
||||
b.handleConsumer(consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, queues...)
|
||||
}
|
||||
|
||||
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
|
||||
msg, err := b.receive(c)
|
||||
@@ -323,13 +426,22 @@ func (b *Broker) dispatchWorker(queue *Queue) {
|
||||
|
||||
func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool {
|
||||
var consumerFound bool
|
||||
var err error
|
||||
queue.consumers.ForEach(func(_ string, con *consumer) bool {
|
||||
if con.state != consts.ConsumerStateActive {
|
||||
err = fmt.Errorf("consumer %s is not active", con.id)
|
||||
return false
|
||||
}
|
||||
if err := b.send(con.conn, task.Message); err == nil {
|
||||
consumerFound = true
|
||||
return false // break the loop once a consumer is found
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
log.Println(err.Error())
|
||||
return false
|
||||
}
|
||||
if !consumerFound {
|
||||
log.Printf("No available consumers for queue %s, retrying...", queue.name)
|
||||
}
|
||||
|
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
"github.com/oarkflow/mq/consts"
|
||||
)
|
||||
@@ -15,7 +16,7 @@ type Message struct {
|
||||
Queue string `json:"q"`
|
||||
Command consts.CMD `json:"c"`
|
||||
Payload json.RawMessage `json:"p"`
|
||||
// Metadata map[string]any `json:"m"`
|
||||
m sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers map[string]string) *Message {
|
||||
@@ -24,14 +25,13 @@ func NewMessage(cmd consts.CMD, payload json.RawMessage, queue string, headers m
|
||||
Queue: queue,
|
||||
Command: cmd,
|
||||
Payload: payload,
|
||||
// Metadata: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, string, error) {
|
||||
m.m.Lock()
|
||||
defer m.m.Unlock()
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Serialize Headers, Topic, Command, Payload, and Metadata
|
||||
if err := writeLengthPrefixedJSON(&buf, m.Headers); err != nil {
|
||||
return nil, "", fmt.Errorf("error serializing headers: %v", err)
|
||||
}
|
||||
@@ -44,56 +44,37 @@ func (m *Message) Serialize(aesKey, hmacKey []byte, encrypt bool) ([]byte, strin
|
||||
if err := writePayload(&buf, aesKey, m.Payload, encrypt); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
/*if err := writeLengthPrefixedJSON(&buf, m.Metadata); err != nil {
|
||||
return nil, "", fmt.Errorf("error serializing metadata: %v", err)
|
||||
}*/
|
||||
|
||||
// Calculate HMAC
|
||||
messageBytes := buf.Bytes()
|
||||
hmacSignature := CalculateHMAC(hmacKey, messageBytes)
|
||||
|
||||
return messageBytes, hmacSignature, nil
|
||||
}
|
||||
|
||||
func Deserialize(data, aesKey, hmacKey []byte, receivedHMAC string, decrypt bool) (*Message, error) {
|
||||
if !VerifyHMAC(hmacKey, data, receivedHMAC) {
|
||||
return nil, fmt.Errorf("HMAC verification failed")
|
||||
return nil, fmt.Errorf("HMAC verification failed %s", string(hmacKey))
|
||||
}
|
||||
|
||||
buf := bytes.NewReader(data)
|
||||
|
||||
// Deserialize Headers, Topic, Command, Payload, and Metadata
|
||||
headers := make(map[string]string)
|
||||
if err := readLengthPrefixedJSON(buf, &headers); err != nil {
|
||||
return nil, fmt.Errorf("error deserializing headers: %v", err)
|
||||
}
|
||||
|
||||
topic, err := readLengthPrefixedString(buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error deserializing topic: %v", err)
|
||||
}
|
||||
|
||||
var command consts.CMD
|
||||
if err := binary.Read(buf, binary.LittleEndian, &command); err != nil {
|
||||
return nil, fmt.Errorf("error deserializing command: %v", err)
|
||||
}
|
||||
|
||||
payload, err := readPayload(buf, aesKey, decrypt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error deserializing payload: %v", err)
|
||||
}
|
||||
|
||||
/*metadata := make(map[string]any)
|
||||
if err := readLengthPrefixedJSON(buf, &metadata); err != nil {
|
||||
return nil, fmt.Errorf("error deserializing metadata: %v", err)
|
||||
}*/
|
||||
|
||||
return &Message{
|
||||
Headers: headers,
|
||||
Queue: topic,
|
||||
Command: command,
|
||||
Payload: payload,
|
||||
// Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -102,11 +83,9 @@ func SendMessage(conn io.Writer, msg *Message, aesKey, hmacKey []byte, encrypt b
|
||||
if err != nil {
|
||||
return fmt.Errorf("error serializing message: %v", err)
|
||||
}
|
||||
|
||||
if err := writeMessageWithHMAC(conn, sentData, hmacSignature); err != nil {
|
||||
return fmt.Errorf("error sending message: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -115,7 +94,6 @@ func ReadMessage(conn io.Reader, aesKey, hmacKey []byte, decrypt bool) (*Message
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return Deserialize(data, aesKey, hmacKey, receivedHMAC, decrypt)
|
||||
}
|
||||
|
||||
@@ -195,7 +173,6 @@ func readPayload(r *bytes.Reader, aesKey []byte, decrypt bool) (json.RawMessage,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var payloadBytes []byte
|
||||
if decrypt {
|
||||
nonce := make([]byte, 12)
|
||||
@@ -209,12 +186,10 @@ func readPayload(r *bytes.Reader, aesKey []byte, decrypt bool) (json.RawMessage,
|
||||
} else {
|
||||
payloadBytes = encryptedPayload
|
||||
}
|
||||
|
||||
var payload json.RawMessage
|
||||
if err := json.Unmarshal(payloadBytes, &payload); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling payload: %v", err)
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
||||
func writeMessageWithHMAC(conn io.Writer, messageBytes []byte, hmacSignature string) error {
|
||||
|
@@ -2,7 +2,7 @@ package consts
|
||||
|
||||
type CMD byte
|
||||
|
||||
func (c CMD) IsValid() bool { return c >= PING && c <= STOP }
|
||||
func (c CMD) IsValid() bool { return c >= PING && c <= CONSUMER_STOP }
|
||||
|
||||
const (
|
||||
PING CMD = iota + 1
|
||||
@@ -11,13 +11,29 @@ const (
|
||||
|
||||
MESSAGE_SEND
|
||||
MESSAGE_RESPONSE
|
||||
MESSAGE_DENY
|
||||
MESSAGE_ACK
|
||||
MESSAGE_ERROR
|
||||
|
||||
PUBLISH
|
||||
PUBLISH_ACK
|
||||
RESPONSE
|
||||
STOP
|
||||
|
||||
CONSUMER_PAUSE
|
||||
CONSUMER_RESUME
|
||||
CONSUMER_STOP
|
||||
|
||||
CONSUMER_PAUSED
|
||||
CONSUMER_RESUMED
|
||||
CONSUMER_STOPPED
|
||||
)
|
||||
|
||||
type ConsumerState byte
|
||||
|
||||
const (
|
||||
ConsumerStateActive ConsumerState = iota
|
||||
ConsumerStatePaused
|
||||
ConsumerStateStopped
|
||||
)
|
||||
|
||||
func (c CMD) String() string {
|
||||
@@ -30,6 +46,8 @@ func (c CMD) String() string {
|
||||
return "SUBSCRIBE_ACK"
|
||||
case MESSAGE_SEND:
|
||||
return "MESSAGE_SEND"
|
||||
case MESSAGE_DENY:
|
||||
return "MESSAGE_DENY"
|
||||
case MESSAGE_RESPONSE:
|
||||
return "MESSAGE_RESPONSE"
|
||||
case MESSAGE_ERROR:
|
||||
@@ -40,8 +58,18 @@ func (c CMD) String() string {
|
||||
return "PUBLISH"
|
||||
case PUBLISH_ACK:
|
||||
return "PUBLISH_ACK"
|
||||
case STOP:
|
||||
return "STOP"
|
||||
case CONSUMER_PAUSE:
|
||||
return "CONSUMER_PAUSE"
|
||||
case CONSUMER_RESUME:
|
||||
return "CONSUMER_RESUME"
|
||||
case CONSUMER_STOP:
|
||||
return "CONSUMER_STOP"
|
||||
case CONSUMER_PAUSED:
|
||||
return "CONSUMER_PAUSED"
|
||||
case CONSUMER_RESUMED:
|
||||
return "CONSUMER_RESUMED"
|
||||
case CONSUMER_STOPPED:
|
||||
return "CONSUMER_STOPPED"
|
||||
case RESPONSE:
|
||||
return "RESPONSE"
|
||||
default:
|
||||
|
104
consumer.go
104
consumer.go
@@ -23,6 +23,7 @@ type Consumer struct {
|
||||
conn net.Conn
|
||||
queue string
|
||||
opts Options
|
||||
pool *Pool
|
||||
}
|
||||
|
||||
// NewConsumer initializes a new consumer with the provided options.
|
||||
@@ -44,8 +45,9 @@ func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) {
|
||||
return codec.ReadMessage(conn, c.opts.aesKey, c.opts.hmacKey, c.opts.enableEncryption)
|
||||
}
|
||||
|
||||
// Close closes the consumer's connection.
|
||||
// Close closes the consumer's connection and stops the worker pool.
|
||||
func (c *Consumer) Close() error {
|
||||
c.pool.Stop()
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
@@ -55,11 +57,10 @@ func (c *Consumer) subscribe(ctx context.Context, queue string) error {
|
||||
consts.ConsumerKey: c.id,
|
||||
consts.ContentType: consts.TypeJson,
|
||||
})
|
||||
msg := codec.NewMessage(consts.SUBSCRIBE, nil, queue, headers)
|
||||
msg := codec.NewMessage(consts.SUBSCRIBE, []byte("{}"), queue, headers)
|
||||
if err := c.send(c.conn, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.waitForAck(c.conn)
|
||||
}
|
||||
|
||||
@@ -73,6 +74,30 @@ func (c *Consumer) OnError(_ context.Context, conn net.Conn, err error) {
|
||||
}
|
||||
|
||||
func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||||
switch msg.Command {
|
||||
case consts.PUBLISH:
|
||||
c.ConsumeMessage(ctx, msg, conn)
|
||||
case consts.CONSUMER_PAUSE:
|
||||
err := c.Pause()
|
||||
if err != nil {
|
||||
log.Printf("Unable to pause consumer: %v", err)
|
||||
}
|
||||
case consts.CONSUMER_RESUME:
|
||||
err := c.Resume()
|
||||
if err != nil {
|
||||
log.Printf("Unable to resume consumer: %v", err)
|
||||
}
|
||||
case consts.CONSUMER_STOP:
|
||||
err := c.Stop()
|
||||
if err != nil {
|
||||
log.Printf("Unable to stop consumer: %v", err)
|
||||
}
|
||||
default:
|
||||
log.Printf("CONSUMER - UNKNOWN_COMMAND ~> %s on %s", msg.Command, msg.Queue)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Consumer) ConsumeMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||||
headers := WithHeaders(ctx, map[string]string{
|
||||
consts.ConsumerKey: c.id,
|
||||
consts.ContentType: consts.TypeJson,
|
||||
@@ -89,12 +114,21 @@ func (c *Consumer) OnMessage(ctx context.Context, msg *codec.Message, conn net.C
|
||||
log.Printf("Error unmarshalling message: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx = SetHeaders(ctx, map[string]string{consts.QueueKey: msg.Queue})
|
||||
if !c.opts.enableWorkerPool {
|
||||
result := c.ProcessTask(ctx, &task)
|
||||
err = c.OnResponse(ctx, result)
|
||||
if err != nil {
|
||||
log.Printf("Error on message callback: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
// Add the task to the worker pool
|
||||
if err := c.pool.AddTask(ctx, &task); err != nil {
|
||||
c.sendDenyMessage(ctx, taskID, msg.Queue, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// OnResponse sends the result back to the broker.
|
||||
@@ -120,6 +154,17 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) {
|
||||
headers := WithHeaders(ctx, map[string]string{
|
||||
consts.ConsumerKey: c.id,
|
||||
consts.ContentType: consts.TypeJson,
|
||||
})
|
||||
reply := codec.NewMessage(consts.MESSAGE_DENY, []byte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
|
||||
if sendErr := c.send(c.conn, reply); sendErr != nil {
|
||||
log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr)
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessTask handles a received task message and invokes the appropriate handler.
|
||||
func (c *Consumer) ProcessTask(ctx context.Context, msg *Task) Result {
|
||||
result := c.handler(ctx, msg)
|
||||
@@ -171,10 +216,11 @@ func (c *Consumer) Consume(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.pool = NewPool(c.opts.numOfWorkers, c.opts.queueSize, c.opts.maxMemoryLoad, c.ProcessTask, c.OnResponse, c.conn)
|
||||
if err := c.subscribe(ctx, c.queue); err != nil {
|
||||
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
|
||||
}
|
||||
c.pool.Start(c.opts.numOfWorkers)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -202,3 +248,53 @@ func (c *Consumer) waitForAck(conn net.Conn) error {
|
||||
}
|
||||
return fmt.Errorf("expected SUBSCRIBE_ACK, got: %v", msg.Command)
|
||||
}
|
||||
|
||||
// Additional methods for Pause, Resume, and Stop
|
||||
|
||||
func (c *Consumer) Pause() error {
|
||||
if err := c.sendPauseMessage(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.pool.Pause()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Consumer) sendPauseMessage() error {
|
||||
headers := WithHeaders(context.Background(), map[string]string{
|
||||
consts.ConsumerKey: c.id,
|
||||
})
|
||||
msg := codec.NewMessage(consts.CONSUMER_PAUSED, nil, c.queue, headers)
|
||||
return c.send(c.conn, msg)
|
||||
}
|
||||
|
||||
func (c *Consumer) Resume() error {
|
||||
if err := c.sendResumeMessage(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.pool.Resume()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Consumer) sendResumeMessage() error {
|
||||
headers := WithHeaders(context.Background(), map[string]string{
|
||||
consts.ConsumerKey: c.id,
|
||||
})
|
||||
msg := codec.NewMessage(consts.CONSUMER_RESUMED, nil, c.queue, headers)
|
||||
return c.send(c.conn, msg)
|
||||
}
|
||||
|
||||
func (c *Consumer) Stop() error {
|
||||
if err := c.sendStopMessage(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.pool.Stop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Consumer) sendStopMessage() error {
|
||||
headers := WithHeaders(context.Background(), map[string]string{
|
||||
consts.ConsumerKey: c.id,
|
||||
})
|
||||
msg := codec.NewMessage(consts.CONSUMER_STOPPED, nil, c.queue, headers)
|
||||
return c.send(c.conn, msg)
|
||||
}
|
||||
|
44
dag/dag.go
44
dag/dag.go
@@ -4,12 +4,13 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/oarkflow/xid"
|
||||
"log"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/xid"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
)
|
||||
|
||||
@@ -49,6 +50,8 @@ type DAG struct {
|
||||
taskContext map[string]*TaskManager
|
||||
conditions map[string]map[string]string
|
||||
mu sync.RWMutex
|
||||
paused bool
|
||||
opts []mq.Option
|
||||
}
|
||||
|
||||
func NewDAG(opts ...mq.Option) *DAG {
|
||||
@@ -59,6 +62,7 @@ func NewDAG(opts ...mq.Option) *DAG {
|
||||
}
|
||||
opts = append(opts, mq.WithCallback(d.onTaskCallback), mq.WithConsumerOnSubscribe(d.onConsumerJoin), mq.WithConsumerOnClose(d.onConsumerClose))
|
||||
d.server = mq.NewBroker(opts...)
|
||||
d.opts = opts
|
||||
return d
|
||||
}
|
||||
|
||||
@@ -95,10 +99,13 @@ func (tm *DAG) Start(ctx context.Context, addr string) error {
|
||||
if con.isReady {
|
||||
go func(con *Node) {
|
||||
time.Sleep(1 * time.Second)
|
||||
con.consumer.Consume(ctx)
|
||||
err := con.consumer.Consume(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}(con)
|
||||
} else {
|
||||
log.Printf("[WARNING] - %s is not ready yet", con.Key)
|
||||
log.Printf("[WARNING] - Consumer %s is not ready yet", con.Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -114,7 +121,7 @@ func (tm *DAG) Start(ctx context.Context, addr string) error {
|
||||
func (tm *DAG) AddNode(key string, handler mq.Handler, firstNode ...bool) {
|
||||
tm.mu.Lock()
|
||||
defer tm.mu.Unlock()
|
||||
con := mq.NewConsumer(key, key, handler)
|
||||
con := mq.NewConsumer(key, key, handler, tm.opts...)
|
||||
tm.Nodes[key] = &Node{
|
||||
Key: key,
|
||||
consumer: con,
|
||||
@@ -176,8 +183,11 @@ func (tm *DAG) AddEdge(from, to string, edgeTypes ...EdgeType) {
|
||||
}
|
||||
|
||||
func (tm *DAG) ProcessTask(ctx context.Context, payload []byte) mq.Result {
|
||||
if tm.paused {
|
||||
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not accepting any task")}
|
||||
}
|
||||
if !tm.IsReady() {
|
||||
return mq.Result{Error: fmt.Errorf("DAG is not ready yet")}
|
||||
return mq.Result{Error: fmt.Errorf("unable to process task, error: DAG is not ready yet")}
|
||||
}
|
||||
val := ctx.Value("initial_node")
|
||||
initialNode, ok := val.(string)
|
||||
@@ -226,3 +236,27 @@ func (tm *DAG) FindInitialNode() *Node {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tm *DAG) Pause() {
|
||||
tm.paused = true
|
||||
log.Printf("DAG - PAUSED")
|
||||
}
|
||||
|
||||
func (tm *DAG) Resume() {
|
||||
tm.paused = false
|
||||
log.Printf("DAG - RESUMED")
|
||||
}
|
||||
|
||||
func (tm *DAG) PauseConsumer(id string) {
|
||||
if node, ok := tm.Nodes[id]; ok {
|
||||
node.consumer.Pause()
|
||||
node.isReady = false
|
||||
}
|
||||
}
|
||||
|
||||
func (tm *DAG) ResumeConsumer(id string) {
|
||||
if node, ok := tm.Nodes[id]; ok {
|
||||
node.consumer.Resume()
|
||||
node.isReady = true
|
||||
}
|
||||
}
|
||||
|
@@ -15,7 +15,11 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
d = dag.NewDAG(mq.WithSyncMode(false), mq.WithNotifyResponse(tasks.NotifyResponse))
|
||||
d = dag.NewDAG(
|
||||
mq.WithNotifyResponse(tasks.NotifyResponse),
|
||||
mq.WithWorkerPool(100, 4, 5000000),
|
||||
mq.WithSecretKey([]byte("wKWa6GKdBd0njDKNQoInBbh6P0KTjmob")),
|
||||
)
|
||||
// d = dag.NewDAG(mq.WithSyncMode(true), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert"))
|
||||
)
|
||||
|
||||
@@ -34,6 +38,24 @@ func main() {
|
||||
d.AddEdge("E", "F")
|
||||
http.HandleFunc("POST /publish", requestHandler("publish"))
|
||||
http.HandleFunc("POST /request", requestHandler("request"))
|
||||
http.HandleFunc("/pause-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
|
||||
id := request.PathValue("id")
|
||||
if id != "" {
|
||||
d.PauseConsumer(id)
|
||||
}
|
||||
})
|
||||
http.HandleFunc("/resume-consumer/{id}", func(writer http.ResponseWriter, request *http.Request) {
|
||||
id := request.PathValue("id")
|
||||
if id != "" {
|
||||
d.ResumeConsumer(id)
|
||||
}
|
||||
})
|
||||
http.HandleFunc("/pause", func(writer http.ResponseWriter, request *http.Request) {
|
||||
d.Pause()
|
||||
})
|
||||
http.HandleFunc("/resume", func(writer http.ResponseWriter, request *http.Request) {
|
||||
d.Resume()
|
||||
})
|
||||
err := d.Start(context.TODO(), ":8083")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
34
examples/hmac.go
Normal file
34
examples/hmac.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
)
|
||||
|
||||
func GenerateSecretKey() (string, error) {
|
||||
// Create a byte slice to hold 32 random bytes
|
||||
key := make([]byte, 32)
|
||||
|
||||
// Fill the slice with secure random bytes
|
||||
_, err := rand.Read(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Encode the byte slice to a Base64 string
|
||||
secretKey := base64.StdEncoding.EncodeToString(key)
|
||||
|
||||
// Return the first 32 characters
|
||||
return secretKey[:32], nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
secretKey, err := GenerateSecretKey()
|
||||
if err != nil {
|
||||
log.Fatalf("Error generating secret key: %v", err)
|
||||
}
|
||||
|
||||
fmt.Println("Generated Secret Key:", secretKey)
|
||||
}
|
@@ -9,17 +9,19 @@ import (
|
||||
"github.com/oarkflow/mq"
|
||||
)
|
||||
|
||||
func Node1(ctx context.Context, task *mq.Task) mq.Result {
|
||||
func Node1(_ context.Context, task *mq.Task) mq.Result {
|
||||
fmt.Println("Node 1", string(task.Payload))
|
||||
return mq.Result{Payload: task.Payload, TaskID: task.ID}
|
||||
}
|
||||
|
||||
func Node2(ctx context.Context, task *mq.Task) mq.Result {
|
||||
func Node2(_ context.Context, task *mq.Task) mq.Result {
|
||||
fmt.Println("Node 2", string(task.Payload))
|
||||
return mq.Result{Payload: task.Payload, TaskID: task.ID}
|
||||
}
|
||||
|
||||
func Node3(ctx context.Context, task *mq.Task) mq.Result {
|
||||
func Node3(_ context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
_ = json.Unmarshal(task.Payload, &user)
|
||||
age := int(user["age"].(float64))
|
||||
status := "FAIL"
|
||||
if age > 20 {
|
||||
@@ -30,34 +32,34 @@ func Node3(ctx context.Context, task *mq.Task) mq.Result {
|
||||
return mq.Result{Payload: resultPayload, Status: status}
|
||||
}
|
||||
|
||||
func Node4(ctx context.Context, task *mq.Task) mq.Result {
|
||||
func Node4(_ context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
_ = json.Unmarshal(task.Payload, &user)
|
||||
user["final"] = "D"
|
||||
resultPayload, _ := json.Marshal(user)
|
||||
return mq.Result{Payload: resultPayload}
|
||||
}
|
||||
|
||||
func Node5(ctx context.Context, task *mq.Task) mq.Result {
|
||||
func Node5(_ context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
_ = json.Unmarshal(task.Payload, &user)
|
||||
user["salary"] = "E"
|
||||
resultPayload, _ := json.Marshal(user)
|
||||
return mq.Result{Payload: resultPayload}
|
||||
}
|
||||
|
||||
func Node6(ctx context.Context, task *mq.Task) mq.Result {
|
||||
func Node6(_ context.Context, task *mq.Task) mq.Result {
|
||||
var user map[string]any
|
||||
json.Unmarshal(task.Payload, &user)
|
||||
_ = json.Unmarshal(task.Payload, &user)
|
||||
resultPayload, _ := json.Marshal(map[string]any{"storage": user})
|
||||
return mq.Result{Payload: resultPayload}
|
||||
}
|
||||
|
||||
func Callback(ctx context.Context, task mq.Result) mq.Result {
|
||||
func Callback(_ context.Context, task mq.Result) mq.Result {
|
||||
fmt.Println("Received task", task.TaskID, "Payload", string(task.Payload), task.Error, task.Topic)
|
||||
return mq.Result{}
|
||||
}
|
||||
|
||||
func NotifyResponse(ctx context.Context, result mq.Result) {
|
||||
log.Printf("DAG Final response: TaskID: %s, Payload: %s, Topic: %s", result.TaskID, result.Payload, result.Topic)
|
||||
func NotifyResponse(_ context.Context, result mq.Result) {
|
||||
log.Printf("DAG - FINAL_RESPONSE ~> TaskID: %s, Payload: %s, Topic: %s", result.TaskID, result.Payload, result.Topic)
|
||||
}
|
||||
|
26
options.go
26
options.go
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -72,17 +73,22 @@ type Options struct {
|
||||
hmacKey json.RawMessage
|
||||
enableEncryption bool
|
||||
queueSize int
|
||||
numOfWorkers int
|
||||
maxMemoryLoad int64
|
||||
enableWorkerPool bool
|
||||
}
|
||||
|
||||
func defaultOptions() Options {
|
||||
return Options{
|
||||
syncMode: false,
|
||||
brokerAddr: ":8080",
|
||||
maxRetries: 5,
|
||||
initialDelay: 2 * time.Second,
|
||||
maxBackoff: 20 * time.Second,
|
||||
jitterPercent: 0.5,
|
||||
queueSize: 100,
|
||||
hmacKey: []byte(`a9f4b9415485b70275673b5920182796ea497b5e093ead844a43ea5d77cbc24f`),
|
||||
numOfWorkers: runtime.NumCPU(),
|
||||
maxMemoryLoad: 5000000,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,6 +109,15 @@ func WithNotifyResponse(handler func(ctx context.Context, result Result)) Option
|
||||
}
|
||||
}
|
||||
|
||||
func WithWorkerPool(queueSize, numOfWorkers int, maxMemoryLoad int64) Option {
|
||||
return func(opts *Options) {
|
||||
opts.enableWorkerPool = true
|
||||
opts.queueSize = queueSize
|
||||
opts.numOfWorkers = numOfWorkers
|
||||
opts.maxMemoryLoad = maxMemoryLoad
|
||||
}
|
||||
}
|
||||
|
||||
func WithConsumerOnSubscribe(handler func(ctx context.Context, topic, consumerName string)) Option {
|
||||
return func(opts *Options) {
|
||||
opts.consumerOnSubscribe = handler
|
||||
@@ -115,11 +130,16 @@ func WithConsumerOnClose(handler func(ctx context.Context, topic, consumerName s
|
||||
}
|
||||
}
|
||||
|
||||
func WithEncryption(aesKey, hmacKey json.RawMessage, enableEncryption bool) Option {
|
||||
func WithSecretKey(aesKey json.RawMessage) Option {
|
||||
return func(opts *Options) {
|
||||
opts.aesKey = aesKey
|
||||
opts.enableEncryption = true
|
||||
}
|
||||
}
|
||||
|
||||
func WithHMACKey(hmacKey json.RawMessage) Option {
|
||||
return func(opts *Options) {
|
||||
opts.hmacKey = hmacKey
|
||||
opts.enableEncryption = enableEncryption
|
||||
}
|
||||
}
|
||||
|
||||
|
43
utils/encrypt.go
Normal file
43
utils/encrypt.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func GenerateHMACKey(length int) (string, error) {
|
||||
key := make([]byte, length)
|
||||
_, err := rand.Read(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate random key: %v", err)
|
||||
}
|
||||
return hex.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
func MustGenerateHMACKey(length int) string {
|
||||
key, err := GenerateHMACKey(length)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func GenerateSecretKey(length int) (string, error) {
|
||||
key := make([]byte, length)
|
||||
_, err := rand.Read(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
secretKey := base64.StdEncoding.EncodeToString(key)
|
||||
return secretKey[:length], nil
|
||||
}
|
||||
|
||||
func MustGenerateSecretKey(length int) string {
|
||||
key, err := GenerateSecretKey(length)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return key
|
||||
}
|
Reference in New Issue
Block a user