feat: Add connection

This commit is contained in:
sujit
2024-10-20 23:24:58 +05:45
parent a06396da56
commit 35a79be4ad
9 changed files with 201 additions and 149 deletions

View File

@@ -7,7 +7,6 @@ import (
"log"
"net"
"strings"
"sync"
"time"
"github.com/oarkflow/mq/codec"
@@ -47,12 +46,12 @@ func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Cons
}
}
func (c *Consumer) send(conn net.Conn, msg *codec.Message) error {
return codec.SendMessage(conn, msg)
func (c *Consumer) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
return codec.SendMessage(ctx, conn, msg)
}
func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) {
return codec.ReadMessage(conn)
func (c *Consumer) receive(ctx context.Context, conn net.Conn) (*codec.Message, error) {
return codec.ReadMessage(ctx, conn)
}
func (c *Consumer) Close() error {
@@ -75,10 +74,10 @@ func (c *Consumer) SetKey(key string) {
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
headers := HeadersWithConsumerID(ctx, c.id)
msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
if err := c.send(c.conn, msg); err != nil {
if err := c.send(ctx, c.conn, msg); err != nil {
return fmt.Errorf("error while trying to subscribe: %v", err)
}
return c.waitForAck(c.conn)
return c.waitForAck(ctx, c.conn)
}
func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error {
@@ -122,7 +121,7 @@ func (c *Consumer) sendMessageAck(ctx context.Context, msg *codec.Message, conn
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue)
taskID, _ := jsonparser.GetString(msg.Payload, "id")
reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
if err := c.send(conn, reply); err != nil {
if err := c.send(ctx, conn, reply); err != nil {
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
}
}
@@ -172,7 +171,7 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
if result.Payload != nil || result.Error != nil {
bt, _ := json.Marshal(result)
reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers)
if err := c.send(c.conn, reply); err != nil {
if err := c.send(ctx, c.conn, reply); err != nil {
return fmt.Errorf("failed to send MESSAGE_RESPONSE: %v", err)
}
}
@@ -182,7 +181,7 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) {
headers := HeadersWithConsumerID(ctx, c.id)
reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
if sendErr := c.send(c.conn, reply); sendErr != nil {
if sendErr := c.send(ctx, c.conn, reply); sendErr != nil {
log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr)
}
}
@@ -209,7 +208,7 @@ func (c *Consumer) attemptConnect() error {
}
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
msg, err := c.receive(conn)
msg, err := c.receive(ctx, conn)
if err == nil {
ctx = SetHeaders(ctx, msg.Headers)
return c.OnMessage(ctx, msg, conn)
@@ -235,24 +234,32 @@ func (c *Consumer) Consume(ctx context.Context) error {
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
}
c.pool.Start(c.opts.numOfWorkers)
var wg sync.WaitGroup
wg.Add(1)
stopChan := make(chan struct{})
go func() {
defer wg.Done()
defer close(stopChan) // Signal completion when done
for {
if err := c.readMessage(ctx, c.conn); err != nil {
log.Println("Error reading message:", err)
break
select {
case <-ctx.Done():
log.Println("Context canceled, stopping message reading.")
return
default:
if err := c.readMessage(ctx, c.conn); err != nil {
log.Println("Error reading message:", err)
return // Exit the goroutine on error
}
}
}
}()
wg.Wait()
select {
case <-stopChan:
case <-ctx.Done():
log.Println("Context canceled, performing cleanup.")
}
return nil
}
func (c *Consumer) waitForAck(conn net.Conn) error {
msg, err := c.receive(conn)
func (c *Consumer) waitForAck(ctx context.Context, conn net.Conn) error {
msg, err := c.receive(ctx, conn)
if err != nil {
return err
}
@@ -286,7 +293,7 @@ func (c *Consumer) operate(ctx context.Context, cmd consts.CMD, poolOperation fu
func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
headers := HeadersWithConsumerID(ctx, c.id)
msg := codec.NewMessage(cmd, nil, c.queue, headers)
return c.send(c.conn, msg)
return c.send(ctx, c.conn, msg)
}
func (c *Consumer) Conn() net.Conn {