diff --git a/broker.go b/broker.go index 51110d5..28a26d1 100644 --- a/broker.go +++ b/broker.go @@ -8,10 +8,7 @@ import ( "net" "time" - "github.com/oarkflow/xid" "github.com/oarkflow/xsync" - - "github.com/oarkflow/mq/utils" ) type Handler func(context.Context, Task) Result @@ -39,15 +36,6 @@ type Task struct { Error error `json:"error"` } -type CMD int - -const ( - SUBSCRIBE CMD = iota + 1 - PUBLISH - REQUEST - STOP -) - type Command struct { ID string `json:"id"` Command CMD `json:"command"` @@ -55,7 +43,6 @@ type Command struct { MessageID string `json:"message_id"` Payload json.RawMessage `json:"payload,omitempty"` // Used for carrying the task payload Error string `json:"error,omitempty"` - Options map[string]any `json:"options"` } type Result struct { @@ -83,7 +70,7 @@ func (b *Broker) Send(ctx context.Context, cmd Command) error { return errors.New("invalid queue or not exists") } for client := range queue.conn { - err := utils.Write(ctx, client, cmd) + err := Write(ctx, client, cmd) if err != nil { return err } @@ -106,7 +93,7 @@ func (b *Broker) Start(ctx context.Context, addr string) error { fmt.Println("Error accepting connection:", err) continue } - go utils.ReadFromConn(ctx, conn, b.readMessage) + go ReadFromConn(ctx, conn, b.readMessage) } } @@ -116,12 +103,12 @@ func (b *Broker) Publish(ctx context.Context, message Task, queueName string) er return err } if len(queue.conn) == 0 { - queue.deferred.Set(xid.New().String(), &message) + queue.deferred.Set(NewID(), &message) fmt.Println("task deferred as no conn are connected", queueName) return nil } for client := range queue.conn { - err = utils.Write(ctx, client, message) + err = Write(ctx, client, message) if err != nil { return err } @@ -145,7 +132,7 @@ func (b *Broker) AddMessageToQueue(message *Task, queueName string) (*Queue, err return nil, fmt.Errorf("queue %s not found", queueName) } if message.ID == "" { - message.ID = xid.New().String() + message.ID = NewID() } if queueName != "" { message.CurrentQueue = queueName @@ -242,7 +229,7 @@ func (b *Broker) publish(ctx context.Context, conn net.Conn, msg Command) error Status: "success", Queue: msg.Queue, } - return utils.Write(ctx, conn, result) + return Write(ctx, conn, result) } return nil } @@ -265,7 +252,7 @@ func (b *Broker) request(ctx context.Context, conn net.Conn, msg Command) error Status: "success", Queue: msg.Queue, } - return utils.Write(ctx, conn, result) + return Write(ctx, conn, result) } return nil } diff --git a/constants.go b/constants.go new file mode 100644 index 0000000..9cdece4 --- /dev/null +++ b/constants.go @@ -0,0 +1,18 @@ +package mq + +type CMD int + +const ( + SUBSCRIBE CMD = iota + 1 + PUBLISH + REQUEST + STOP +) + +var ( + ConsumerKey = "Consumer-Key" + PublisherKey = "Publisher-Key" + ContentType = "Content-Type" + TypeJson = "application/json" + HeaderKey = "headers" +) diff --git a/consumer.go b/consumer.go index c916e3e..9325c70 100644 --- a/consumer.go +++ b/consumer.go @@ -10,10 +10,6 @@ import ( "slices" "sync" "time" - - "github.com/oarkflow/xid" - - "github.com/oarkflow/mq/utils" ) type Consumer struct { @@ -39,15 +35,16 @@ func (c *Consumer) Close() error { func (c *Consumer) subscribe(queue string) error { ctx := context.Background() + ctx = SetHeaders(ctx, map[string]string{ + ConsumerKey: c.id, + ContentType: TypeJson, + }) subscribe := Command{ Command: SUBSCRIBE, Queue: queue, - ID: xid.New().String(), - Options: map[string]any{ - "consumer_id": c.id, - }, + ID: NewID(), } - return utils.Write(ctx, c.conn, subscribe) + return Write(ctx, c.conn, subscribe) } func (c *Consumer) ProcessTask(ctx context.Context, msg Task) Result { @@ -81,7 +78,7 @@ func (c *Consumer) handleTaskMessage(ctx context.Context, msg Task) error { } func (c *Consumer) sendResult(ctx context.Context, response Result) error { - return utils.Write(ctx, c.conn, response) + return Write(ctx, c.conn, response) } func (c *Consumer) readMessage(ctx context.Context, message []byte) error { @@ -141,7 +138,7 @@ func (c *Consumer) Consume(ctx context.Context, queues ...string) error { wg.Add(1) go func() { defer wg.Done() - utils.ReadFromConn(ctx, c.conn, func(ctx context.Context, conn net.Conn, message []byte) error { + ReadFromConn(ctx, c.conn, func(ctx context.Context, conn net.Conn, message []byte) error { return c.readMessage(ctx, message) }) fmt.Println("Stopping consumer") diff --git a/utils/utils.go b/ctx.go similarity index 58% rename from utils/utils.go rename to ctx.go index eda8e64..37d1964 100644 --- a/utils/utils.go +++ b/ctx.go @@ -1,4 +1,4 @@ -package utils +package mq import ( "bufio" @@ -9,12 +9,13 @@ import ( "io" "net" "strings" + + "github.com/oarkflow/xid" ) type Message struct { - Headers map[string]string `json:"headers"` - Data json.RawMessage `json:"data"` - TriggerNode string `json:"triggerNode"` + Headers map[string]string `json:"headers"` + Data json.RawMessage `json:"data"` } func IsClosed(conn net.Conn) bool { @@ -27,32 +28,47 @@ func IsClosed(conn net.Conn) bool { return false } -func SetHeadersToContext(ctx context.Context, headers map[string]string) context.Context { - return context.WithValue(ctx, "headers", headers) +func SetHeaders(ctx context.Context, headers map[string]string) context.Context { + return context.WithValue(ctx, HeaderKey, headers) } -func GetHeadersFromContext(ctx context.Context) (map[string]string, bool) { - headers, ok := ctx.Value("headers").(map[string]string) +func GetHeaders(ctx context.Context) (map[string]string, bool) { + headers, ok := ctx.Value(HeaderKey).(map[string]string) return headers, ok } -func SetTriggerNodeToContext(ctx context.Context, triggerNode string) context.Context { - return context.WithValue(ctx, "triggerNode", triggerNode) +func GetContentType(ctx context.Context) (string, bool) { + headers, ok := GetHeaders(ctx) + if !ok { + return "", false + } + contentType, ok := headers[ContentType] + return contentType, ok } -func GetTriggerNodeFromContext(ctx context.Context) (string, bool) { - headers, ok := ctx.Value("triggerNode").(string) - return headers, ok +func GetConsumerID(ctx context.Context) (string, bool) { + headers, ok := GetHeaders(ctx) + if !ok { + return "", false + } + contentType, ok := headers[ConsumerKey] + return contentType, ok +} + +func GetPublisherID(ctx context.Context) (string, bool) { + headers, ok := GetHeaders(ctx) + if !ok { + return "", false + } + contentType, ok := headers[PublisherKey] + return contentType, ok } func Write(ctx context.Context, conn net.Conn, data any) error { msg := Message{Headers: make(map[string]string)} - if headers, ok := GetHeadersFromContext(ctx); ok { + if headers, ok := GetHeaders(ctx); ok { msg.Headers = headers } - if trigger, ok := GetTriggerNodeFromContext(ctx); ok { - msg.TriggerNode = trigger - } dataBytes, err := json.Marshal(data) if err != nil { return err @@ -90,8 +106,7 @@ func ReadFromConn(ctx context.Context, conn net.Conn, handler MessageHandler) { fmt.Println("Error unmarshalling message:", err) continue } - ctx = SetHeadersToContext(ctx, msg.Headers) - ctx = SetTriggerNodeToContext(ctx, msg.TriggerNode) + ctx = SetHeaders(ctx, msg.Headers) if handler != nil { err = handler(ctx, conn, msg.Data) if err != nil { @@ -101,3 +116,7 @@ func ReadFromConn(ctx context.Context, conn net.Conn, handler MessageHandler) { } } } + +func NewID() string { + return xid.New().String() +} diff --git a/publisher.go b/publisher.go index be39098..0fe67d4 100644 --- a/publisher.go +++ b/publisher.go @@ -4,10 +4,6 @@ import ( "context" "fmt" "net" - - "github.com/oarkflow/xid" - - "github.com/oarkflow/mq/utils" ) type Publisher struct { @@ -25,20 +21,20 @@ func (p *Publisher) Publish(ctx context.Context, queue string, task Task) error return fmt.Errorf("failed to connect to broker: %w", err) } defer conn.Close() - + ctx = SetHeaders(ctx, map[string]string{ + PublisherKey: p.id, + ContentType: TypeJson, + }) cmd := Command{ - ID: xid.New().String(), + ID: NewID(), Command: PUBLISH, Queue: queue, MessageID: task.ID, Payload: task.Payload, - Options: map[string]any{ - "publisher_id": p.id, - }, } // Fire and forget: No need to wait for response - return utils.Write(ctx, conn, cmd) + return Write(ctx, conn, cmd) } func (p *Publisher) Request(ctx context.Context, queue string, task Task) (Result, error) { @@ -47,22 +43,23 @@ func (p *Publisher) Request(ctx context.Context, queue string, task Task) (Resul return Result{}, fmt.Errorf("failed to connect to broker: %w", err) } defer conn.Close() + ctx = SetHeaders(ctx, map[string]string{ + PublisherKey: p.id, + ContentType: TypeJson, + }) cmd := Command{ - ID: xid.New().String(), + ID: NewID(), Command: REQUEST, Queue: queue, MessageID: task.ID, Payload: task.Payload, - Options: map[string]any{ - "publisher_id": p.id, - }, } var result Result - err = utils.Write(ctx, conn, cmd) + err = Write(ctx, conn, cmd) if err != nil { return result, err } - utils.ReadFromConn(ctx, conn, func(ctx context.Context, conn net.Conn, bytes []byte) error { + ReadFromConn(ctx, conn, func(ctx context.Context, conn net.Conn, bytes []byte) error { fmt.Println(string(bytes)) return nil })