mirror of
https://github.com/oarkflow/mq.git
synced 2025-09-26 20:11:16 +08:00
feat: Add connection
This commit is contained in:
51
consumer.go
51
consumer.go
@@ -7,7 +7,6 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/codec"
|
||||
@@ -47,12 +46,12 @@ func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Cons
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Consumer) send(conn net.Conn, msg *codec.Message) error {
|
||||
return codec.SendMessage(conn, msg)
|
||||
func (c *Consumer) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
|
||||
return codec.SendMessage(ctx, conn, msg)
|
||||
}
|
||||
|
||||
func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) {
|
||||
return codec.ReadMessage(conn)
|
||||
func (c *Consumer) receive(ctx context.Context, conn net.Conn) (*codec.Message, error) {
|
||||
return codec.ReadMessage(ctx, conn)
|
||||
}
|
||||
|
||||
func (c *Consumer) Close() error {
|
||||
@@ -75,10 +74,10 @@ func (c *Consumer) SetKey(key string) {
|
||||
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
|
||||
headers := HeadersWithConsumerID(ctx, c.id)
|
||||
msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
|
||||
if err := c.send(c.conn, msg); err != nil {
|
||||
if err := c.send(ctx, c.conn, msg); err != nil {
|
||||
return fmt.Errorf("error while trying to subscribe: %v", err)
|
||||
}
|
||||
return c.waitForAck(c.conn)
|
||||
return c.waitForAck(ctx, c.conn)
|
||||
}
|
||||
|
||||
func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error {
|
||||
@@ -122,7 +121,7 @@ func (c *Consumer) sendMessageAck(ctx context.Context, msg *codec.Message, conn
|
||||
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue)
|
||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||
reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
|
||||
if err := c.send(conn, reply); err != nil {
|
||||
if err := c.send(ctx, conn, reply); err != nil {
|
||||
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
|
||||
}
|
||||
}
|
||||
@@ -172,7 +171,7 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
|
||||
if result.Payload != nil || result.Error != nil {
|
||||
bt, _ := json.Marshal(result)
|
||||
reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers)
|
||||
if err := c.send(c.conn, reply); err != nil {
|
||||
if err := c.send(ctx, c.conn, reply); err != nil {
|
||||
return fmt.Errorf("failed to send MESSAGE_RESPONSE: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -182,7 +181,7 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
|
||||
func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) {
|
||||
headers := HeadersWithConsumerID(ctx, c.id)
|
||||
reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
|
||||
if sendErr := c.send(c.conn, reply); sendErr != nil {
|
||||
if sendErr := c.send(ctx, c.conn, reply); sendErr != nil {
|
||||
log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr)
|
||||
}
|
||||
}
|
||||
@@ -209,7 +208,7 @@ func (c *Consumer) attemptConnect() error {
|
||||
}
|
||||
|
||||
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
|
||||
msg, err := c.receive(conn)
|
||||
msg, err := c.receive(ctx, conn)
|
||||
if err == nil {
|
||||
ctx = SetHeaders(ctx, msg.Headers)
|
||||
return c.OnMessage(ctx, msg, conn)
|
||||
@@ -235,24 +234,32 @@ func (c *Consumer) Consume(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
|
||||
}
|
||||
c.pool.Start(c.opts.numOfWorkers)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
stopChan := make(chan struct{})
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer close(stopChan) // Signal completion when done
|
||||
for {
|
||||
if err := c.readMessage(ctx, c.conn); err != nil {
|
||||
log.Println("Error reading message:", err)
|
||||
break
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Println("Context canceled, stopping message reading.")
|
||||
return
|
||||
default:
|
||||
if err := c.readMessage(ctx, c.conn); err != nil {
|
||||
log.Println("Error reading message:", err)
|
||||
return // Exit the goroutine on error
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
select {
|
||||
case <-stopChan:
|
||||
case <-ctx.Done():
|
||||
log.Println("Context canceled, performing cleanup.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Consumer) waitForAck(conn net.Conn) error {
|
||||
msg, err := c.receive(conn)
|
||||
func (c *Consumer) waitForAck(ctx context.Context, conn net.Conn) error {
|
||||
msg, err := c.receive(ctx, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -286,7 +293,7 @@ func (c *Consumer) operate(ctx context.Context, cmd consts.CMD, poolOperation fu
|
||||
func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
|
||||
headers := HeadersWithConsumerID(ctx, c.id)
|
||||
msg := codec.NewMessage(cmd, nil, c.queue, headers)
|
||||
return c.send(c.conn, msg)
|
||||
return c.send(ctx, c.conn, msg)
|
||||
}
|
||||
|
||||
func (c *Consumer) Conn() net.Conn {
|
||||
|
Reference in New Issue
Block a user