mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-03 23:26:28 +08:00
feat: Add connection
This commit is contained in:
61
broker.go
61
broker.go
@@ -75,18 +75,15 @@ func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
|
|||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
b.consumers.ForEach(func(consumerID string, con *consumer) bool {
|
b.consumers.ForEach(func(consumerID string, con *consumer) bool {
|
||||||
if con.conn.RemoteAddr().String() == conn.RemoteAddr().String() &&
|
if utils.ConnectionsEqual(conn, con.conn) {
|
||||||
con.conn.LocalAddr().String() == conn.LocalAddr().String() {
|
con.conn.Close()
|
||||||
if c, exists := b.consumers.Get(consumerID); exists {
|
b.consumers.Del(consumerID)
|
||||||
c.conn.Close()
|
|
||||||
b.consumers.Del(consumerID)
|
|
||||||
}
|
|
||||||
b.queues.ForEach(func(_ string, queue *Queue) bool {
|
b.queues.ForEach(func(_ string, queue *Queue) bool {
|
||||||
|
queue.consumers.Del(consumerID)
|
||||||
if _, ok := queue.consumers.Get(consumerID); ok {
|
if _, ok := queue.consumers.Get(consumerID); ok {
|
||||||
if b.opts.consumerOnClose != nil {
|
if b.opts.consumerOnClose != nil {
|
||||||
b.opts.consumerOnClose(ctx, queue.name, consumerID)
|
b.opts.consumerOnClose(ctx, queue.name, consumerID)
|
||||||
}
|
}
|
||||||
queue.consumers.Del(consumerID)
|
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
@@ -189,7 +186,7 @@ func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message)
|
|||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err := b.send(con.conn, msg)
|
err := b.send(ctx, con.conn, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -202,7 +199,7 @@ func (b *Broker) Publish(ctx context.Context, task *Task, queue string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers.AsMap())
|
msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers.AsMap())
|
||||||
b.broadcastToConsumers(ctx, msg)
|
b.broadcastToConsumers(msg)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,10 +209,10 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
|
|||||||
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
|
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
|
||||||
|
|
||||||
ack := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
|
ack := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
|
||||||
if err := b.send(conn, ack); err != nil {
|
if err := b.send(ctx, conn, ack); err != nil {
|
||||||
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
|
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
|
||||||
}
|
}
|
||||||
b.broadcastToConsumers(ctx, msg)
|
b.broadcastToConsumers(msg)
|
||||||
go func() {
|
go func() {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@@ -227,7 +224,7 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
|
|||||||
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
||||||
consumerID := b.AddConsumer(ctx, msg.Queue, conn)
|
consumerID := b.AddConsumer(ctx, msg.Queue, conn)
|
||||||
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
|
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
|
||||||
if err := b.send(conn, ack); err != nil {
|
if err := b.send(ctx, conn, ack); err != nil {
|
||||||
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
|
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
|
||||||
}
|
}
|
||||||
if b.opts.consumerOnSubscribe != nil {
|
if b.opts.consumerOnSubscribe != nil {
|
||||||
@@ -284,23 +281,23 @@ func (b *Broker) Start(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) send(conn net.Conn, msg *codec.Message) error {
|
func (b *Broker) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
|
||||||
return codec.SendMessage(conn, msg)
|
return codec.SendMessage(ctx, conn, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) receive(c net.Conn) (*codec.Message, error) {
|
func (b *Broker) receive(ctx context.Context, c net.Conn) (*codec.Message, error) {
|
||||||
return codec.ReadMessage(c)
|
return codec.ReadMessage(ctx, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) {
|
func (b *Broker) broadcastToConsumers(msg *codec.Message) {
|
||||||
if queue, ok := b.queues.Get(msg.Queue); ok {
|
if queue, ok := b.queues.Get(msg.Queue); ok {
|
||||||
task := &QueuedTask{Message: msg, RetryCount: 0}
|
task := &QueuedTask{Message: msg, RetryCount: 0}
|
||||||
queue.tasks <- task
|
queue.tasks <- task
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) waitForConsumerAck(conn net.Conn) error {
|
func (b *Broker) waitForConsumerAck(ctx context.Context, conn net.Conn) error {
|
||||||
msg, err := b.receive(conn)
|
msg, err := b.receive(ctx, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -360,12 +357,12 @@ func (b *Broker) RemoveConsumer(consumerID string, queues ...string) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) handleConsumer(cmd consts.CMD, state consts.ConsumerState, consumerID string, queues ...string) {
|
func (b *Broker) handleConsumer(ctx context.Context, cmd consts.CMD, state consts.ConsumerState, consumerID string, queues ...string) {
|
||||||
fn := func(queue *Queue) {
|
fn := func(queue *Queue) {
|
||||||
con, ok := queue.consumers.Get(consumerID)
|
con, ok := queue.consumers.Get(consumerID)
|
||||||
if ok {
|
if ok {
|
||||||
ack := codec.NewMessage(cmd, utils.ToByte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID})
|
ack := codec.NewMessage(cmd, utils.ToByte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID})
|
||||||
err := b.send(con.conn, ack)
|
err := b.send(ctx, con.conn, ack)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
con.state = state
|
con.state = state
|
||||||
}
|
}
|
||||||
@@ -385,20 +382,20 @@ func (b *Broker) handleConsumer(cmd consts.CMD, state consts.ConsumerState, cons
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) PauseConsumer(consumerID string, queues ...string) {
|
func (b *Broker) PauseConsumer(ctx context.Context, consumerID string, queues ...string) {
|
||||||
b.handleConsumer(consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, queues...)
|
b.handleConsumer(ctx, consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, queues...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) ResumeConsumer(consumerID string, queues ...string) {
|
func (b *Broker) ResumeConsumer(ctx context.Context, consumerID string, queues ...string) {
|
||||||
b.handleConsumer(consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, queues...)
|
b.handleConsumer(ctx, consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, queues...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) StopConsumer(consumerID string, queues ...string) {
|
func (b *Broker) StopConsumer(ctx context.Context, consumerID string, queues ...string) {
|
||||||
b.handleConsumer(consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, queues...)
|
b.handleConsumer(ctx, consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, queues...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
|
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
|
||||||
msg, err := b.receive(c)
|
msg, err := b.receive(ctx, c)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
ctx = SetHeaders(ctx, msg.Headers)
|
ctx = SetHeaders(ctx, msg.Headers)
|
||||||
b.OnMessage(ctx, msg, c)
|
b.OnMessage(ctx, msg, c)
|
||||||
@@ -412,12 +409,12 @@ func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) dispatchWorker(queue *Queue) {
|
func (b *Broker) dispatchWorker(ctx context.Context, queue *Queue) {
|
||||||
delay := b.opts.initialDelay
|
delay := b.opts.initialDelay
|
||||||
for task := range queue.tasks {
|
for task := range queue.tasks {
|
||||||
success := false
|
success := false
|
||||||
for !success && task.RetryCount <= b.opts.maxRetries {
|
for !success && task.RetryCount <= b.opts.maxRetries {
|
||||||
if b.dispatchTaskToConsumer(queue, task) {
|
if b.dispatchTaskToConsumer(ctx, queue, task) {
|
||||||
success = true
|
success = true
|
||||||
} else {
|
} else {
|
||||||
task.RetryCount++
|
task.RetryCount++
|
||||||
@@ -440,7 +437,7 @@ func (b *Broker) sendToDLQ(queue *Queue, task *QueuedTask) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool {
|
func (b *Broker) dispatchTaskToConsumer(ctx context.Context, queue *Queue, task *QueuedTask) bool {
|
||||||
var consumerFound bool
|
var consumerFound bool
|
||||||
var err error
|
var err error
|
||||||
queue.consumers.ForEach(func(_ string, con *consumer) bool {
|
queue.consumers.ForEach(func(_ string, con *consumer) bool {
|
||||||
@@ -448,7 +445,7 @@ func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool {
|
|||||||
err = fmt.Errorf("consumer %s is not active", con.id)
|
err = fmt.Errorf("consumer %s is not active", con.id)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if err := b.send(con.conn, task.Message); err == nil {
|
if err := b.send(ctx, con.conn, task.Message); err == nil {
|
||||||
consumerFound = true
|
consumerFound = true
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@@ -1,6 +1,8 @@
|
|||||||
package codec
|
package codec
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -12,11 +14,13 @@ type Message struct {
|
|||||||
Headers map[string]string `msgpack:"h"`
|
Headers map[string]string `msgpack:"h"`
|
||||||
Queue string `msgpack:"q"`
|
Queue string `msgpack:"q"`
|
||||||
Payload []byte `msgpack:"p"`
|
Payload []byte `msgpack:"p"`
|
||||||
m sync.RWMutex
|
Command consts.CMD `msgpack:"c"`
|
||||||
Command consts.CMD `msgpack:"c"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string]string) *Message {
|
func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string]string) *Message {
|
||||||
|
if headers == nil {
|
||||||
|
headers = make(map[string]string)
|
||||||
|
}
|
||||||
return &Message{
|
return &Message{
|
||||||
Headers: headers,
|
Headers: headers,
|
||||||
Queue: queue,
|
Queue: queue,
|
||||||
@@ -26,8 +30,6 @@ func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Message) Serialize() ([]byte, error) {
|
func (m *Message) Serialize() ([]byte, error) {
|
||||||
m.m.RLock()
|
|
||||||
defer m.m.RUnlock()
|
|
||||||
data, err := Marshal(m)
|
data, err := Marshal(m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -40,37 +42,66 @@ func Deserialize(data []byte) (*Message, error) {
|
|||||||
if err := Unmarshal(data, &msg); err != nil {
|
if err := Unmarshal(data, &msg); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &msg, nil
|
return &msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SendMessage(conn net.Conn, msg *Message) error {
|
var byteBufferPool = sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return make([]byte, 0, 4096)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func SendMessage(ctx context.Context, conn net.Conn, msg *Message) error {
|
||||||
data, err := msg.Serialize()
|
data, err := msg.Serialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
length := make([]byte, 4)
|
totalLength := 4 + len(data)
|
||||||
binary.BigEndian.PutUint32(length, uint32(len(data)))
|
buffer := byteBufferPool.Get().([]byte)
|
||||||
|
if cap(buffer) < totalLength {
|
||||||
|
buffer = make([]byte, totalLength)
|
||||||
|
} else {
|
||||||
|
buffer = buffer[:totalLength]
|
||||||
|
}
|
||||||
|
defer byteBufferPool.Put(buffer)
|
||||||
|
|
||||||
if _, err := conn.Write(length); err != nil {
|
binary.BigEndian.PutUint32(buffer[:4], uint32(len(data)))
|
||||||
return err
|
copy(buffer[4:], data)
|
||||||
|
|
||||||
|
writer := bufio.NewWriter(conn)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
if _, err := writer.Write(buffer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if _, err := conn.Write(data); err != nil {
|
|
||||||
return err
|
return writer.Flush()
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadMessage(conn net.Conn) (*Message, error) {
|
func ReadMessage(ctx context.Context, conn net.Conn) (*Message, error) {
|
||||||
lengthBytes := make([]byte, 4)
|
lengthBytes := make([]byte, 4)
|
||||||
if _, err := conn.Read(lengthBytes); err != nil {
|
if _, err := conn.Read(lengthBytes); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
length := binary.BigEndian.Uint32(lengthBytes)
|
length := binary.BigEndian.Uint32(lengthBytes)
|
||||||
data := make([]byte, length)
|
data := byteBufferPool.Get().([]byte)[:length]
|
||||||
if _, err := conn.Read(data); err != nil {
|
defer byteBufferPool.Put(data)
|
||||||
return nil, err
|
totalRead := 0
|
||||||
|
for totalRead < int(length) {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
default:
|
||||||
|
n, err := conn.Read(data[totalRead:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
totalRead += n
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return Deserialize(data)
|
return Deserialize(data[:length])
|
||||||
}
|
}
|
||||||
|
51
consumer.go
51
consumer.go
@@ -7,7 +7,6 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oarkflow/mq/codec"
|
"github.com/oarkflow/mq/codec"
|
||||||
@@ -47,12 +46,12 @@ func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Cons
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Consumer) send(conn net.Conn, msg *codec.Message) error {
|
func (c *Consumer) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
|
||||||
return codec.SendMessage(conn, msg)
|
return codec.SendMessage(ctx, conn, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) {
|
func (c *Consumer) receive(ctx context.Context, conn net.Conn) (*codec.Message, error) {
|
||||||
return codec.ReadMessage(conn)
|
return codec.ReadMessage(ctx, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Consumer) Close() error {
|
func (c *Consumer) Close() error {
|
||||||
@@ -75,10 +74,10 @@ func (c *Consumer) SetKey(key string) {
|
|||||||
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
|
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
|
||||||
headers := HeadersWithConsumerID(ctx, c.id)
|
headers := HeadersWithConsumerID(ctx, c.id)
|
||||||
msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
|
msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
|
||||||
if err := c.send(c.conn, msg); err != nil {
|
if err := c.send(ctx, c.conn, msg); err != nil {
|
||||||
return fmt.Errorf("error while trying to subscribe: %v", err)
|
return fmt.Errorf("error while trying to subscribe: %v", err)
|
||||||
}
|
}
|
||||||
return c.waitForAck(c.conn)
|
return c.waitForAck(ctx, c.conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error {
|
func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error {
|
||||||
@@ -122,7 +121,7 @@ func (c *Consumer) sendMessageAck(ctx context.Context, msg *codec.Message, conn
|
|||||||
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue)
|
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue)
|
||||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||||
reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
|
reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
|
||||||
if err := c.send(conn, reply); err != nil {
|
if err := c.send(ctx, conn, reply); err != nil {
|
||||||
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
|
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -172,7 +171,7 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
|
|||||||
if result.Payload != nil || result.Error != nil {
|
if result.Payload != nil || result.Error != nil {
|
||||||
bt, _ := json.Marshal(result)
|
bt, _ := json.Marshal(result)
|
||||||
reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers)
|
reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers)
|
||||||
if err := c.send(c.conn, reply); err != nil {
|
if err := c.send(ctx, c.conn, reply); err != nil {
|
||||||
return fmt.Errorf("failed to send MESSAGE_RESPONSE: %v", err)
|
return fmt.Errorf("failed to send MESSAGE_RESPONSE: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -182,7 +181,7 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
|
|||||||
func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) {
|
func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) {
|
||||||
headers := HeadersWithConsumerID(ctx, c.id)
|
headers := HeadersWithConsumerID(ctx, c.id)
|
||||||
reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
|
reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
|
||||||
if sendErr := c.send(c.conn, reply); sendErr != nil {
|
if sendErr := c.send(ctx, c.conn, reply); sendErr != nil {
|
||||||
log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr)
|
log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -209,7 +208,7 @@ func (c *Consumer) attemptConnect() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
|
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
|
||||||
msg, err := c.receive(conn)
|
msg, err := c.receive(ctx, conn)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
ctx = SetHeaders(ctx, msg.Headers)
|
ctx = SetHeaders(ctx, msg.Headers)
|
||||||
return c.OnMessage(ctx, msg, conn)
|
return c.OnMessage(ctx, msg, conn)
|
||||||
@@ -235,24 +234,32 @@ func (c *Consumer) Consume(ctx context.Context) error {
|
|||||||
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
|
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
|
||||||
}
|
}
|
||||||
c.pool.Start(c.opts.numOfWorkers)
|
c.pool.Start(c.opts.numOfWorkers)
|
||||||
var wg sync.WaitGroup
|
stopChan := make(chan struct{})
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer close(stopChan) // Signal completion when done
|
||||||
for {
|
for {
|
||||||
if err := c.readMessage(ctx, c.conn); err != nil {
|
select {
|
||||||
log.Println("Error reading message:", err)
|
case <-ctx.Done():
|
||||||
break
|
log.Println("Context canceled, stopping message reading.")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if err := c.readMessage(ctx, c.conn); err != nil {
|
||||||
|
log.Println("Error reading message:", err)
|
||||||
|
return // Exit the goroutine on error
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
select {
|
||||||
wg.Wait()
|
case <-stopChan:
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Println("Context canceled, performing cleanup.")
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Consumer) waitForAck(conn net.Conn) error {
|
func (c *Consumer) waitForAck(ctx context.Context, conn net.Conn) error {
|
||||||
msg, err := c.receive(conn)
|
msg, err := c.receive(ctx, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -286,7 +293,7 @@ func (c *Consumer) operate(ctx context.Context, cmd consts.CMD, poolOperation fu
|
|||||||
func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
|
func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
|
||||||
headers := HeadersWithConsumerID(ctx, c.id)
|
headers := HeadersWithConsumerID(ctx, c.id)
|
||||||
msg := codec.NewMessage(cmd, nil, c.queue, headers)
|
msg := codec.NewMessage(cmd, nil, c.queue, headers)
|
||||||
return c.send(c.conn, msg)
|
return c.send(ctx, c.conn, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Consumer) Conn() net.Conn {
|
func (c *Consumer) Conn() net.Conn {
|
||||||
|
@@ -9,6 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
consumer1 := mq.NewConsumer("F", "queue1", tasks.Node6{}.ProcessTask, mq.WithWorkerPool(100, 4, 50000))
|
n := &tasks.Node6{}
|
||||||
|
consumer1 := mq.NewConsumer("F", "queue1", n.ProcessTask, mq.WithWorkerPool(100, 4, 50000))
|
||||||
consumer1.Consume(context.Background())
|
consumer1.Consume(context.Background())
|
||||||
}
|
}
|
||||||
|
@@ -3,7 +3,6 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/oarkflow/mq"
|
"github.com/oarkflow/mq"
|
||||||
)
|
)
|
||||||
@@ -25,7 +24,6 @@ func main() {
|
|||||||
Payload: payload,
|
Payload: payload,
|
||||||
}
|
}
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
err := publisher.Publish(context.Background(), task, "queue1")
|
err := publisher.Publish(context.Background(), task, "queue1")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
14
publisher.go
14
publisher.go
@@ -37,15 +37,15 @@ func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
msg := codec.NewMessage(command, payload, queue, headers)
|
msg := codec.NewMessage(command, payload, queue, headers)
|
||||||
if err := codec.SendMessage(conn, msg); err != nil {
|
if err := codec.SendMessage(ctx, conn, msg); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.waitForAck(conn)
|
return p.waitForAck(ctx, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Publisher) waitForAck(conn net.Conn) error {
|
func (p *Publisher) waitForAck(ctx context.Context, conn net.Conn) error {
|
||||||
msg, err := codec.ReadMessage(conn)
|
msg, err := codec.ReadMessage(ctx, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -57,8 +57,8 @@ func (p *Publisher) waitForAck(conn net.Conn) error {
|
|||||||
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
|
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Publisher) waitForResponse(conn net.Conn) Result {
|
func (p *Publisher) waitForResponse(ctx context.Context, conn net.Conn) Result {
|
||||||
msg, err := codec.ReadMessage(conn)
|
msg, err := codec.ReadMessage(ctx, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Result{Error: err}
|
return Result{Error: err}
|
||||||
}
|
}
|
||||||
@@ -103,7 +103,7 @@ func (p *Publisher) Request(ctx context.Context, queue string, task Task) Result
|
|||||||
resultCh := make(chan Result)
|
resultCh := make(chan Result)
|
||||||
go func() {
|
go func() {
|
||||||
defer close(resultCh)
|
defer close(resultCh)
|
||||||
resultCh <- p.waitForResponse(conn)
|
resultCh <- p.waitForResponse(ctx, conn)
|
||||||
}()
|
}()
|
||||||
finalResult := <-resultCh
|
finalResult := <-resultCh
|
||||||
return finalResult
|
return finalResult
|
||||||
|
6
queue.go
6
queue.go
@@ -36,9 +36,9 @@ func (b *Broker) NewQueue(name string) *Queue {
|
|||||||
consumers: memory.New[string, *consumer](),
|
consumers: memory.New[string, *consumer](),
|
||||||
}
|
}
|
||||||
b.deadLetter.Set(name, dlq)
|
b.deadLetter.Set(name, dlq)
|
||||||
|
ctx := context.Background()
|
||||||
go b.dispatchWorker(q)
|
go b.dispatchWorker(ctx, q)
|
||||||
go b.dispatchWorker(dlq)
|
go b.dispatchWorker(ctx, dlq)
|
||||||
return q
|
return q
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -8,94 +8,97 @@ import (
|
|||||||
|
|
||||||
var _ storage.IMap[string, any] = (*Map[string, any])(nil)
|
var _ storage.IMap[string, any] = (*Map[string, any])(nil)
|
||||||
|
|
||||||
|
// Map is a thread-safe map using sync.Map with generics
|
||||||
type Map[K comparable, V any] struct {
|
type Map[K comparable, V any] struct {
|
||||||
data map[K]V
|
m sync.Map
|
||||||
mu sync.RWMutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// New creates a new Map
|
||||||
func New[K comparable, V any]() *Map[K, V] {
|
func New[K comparable, V any]() *Map[K, V] {
|
||||||
return &Map[K, V]{
|
return &Map[K, V]{}
|
||||||
data: make(map[K]V),
|
}
|
||||||
|
|
||||||
|
// Get retrieves the value for a given key
|
||||||
|
func (g *Map[K, V]) Get(key K) (V, bool) {
|
||||||
|
val, ok := g.m.Load(key)
|
||||||
|
if !ok {
|
||||||
|
var zeroValue V
|
||||||
|
return zeroValue, false
|
||||||
}
|
}
|
||||||
|
return val.(V), true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[K, V]) Get(key K) (V, bool) {
|
// Set adds a key-value pair to the map
|
||||||
m.mu.RLock()
|
func (g *Map[K, V]) Set(key K, value V) {
|
||||||
defer m.mu.RUnlock()
|
g.m.Store(key, value)
|
||||||
val, exists := m.data[key]
|
|
||||||
return val, exists
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[K, V]) Set(key K, value V) {
|
// Del removes a key-value pair from the map
|
||||||
m.mu.Lock()
|
func (g *Map[K, V]) Del(key K) {
|
||||||
defer m.mu.Unlock()
|
g.m.Delete(key)
|
||||||
m.data[key] = value
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[K, V]) Del(key K) {
|
// ForEach iterates over the map
|
||||||
m.mu.Lock()
|
func (g *Map[K, V]) ForEach(fn func(K, V) bool) {
|
||||||
defer m.mu.Unlock()
|
g.m.Range(func(k, v any) bool {
|
||||||
delete(m.data, key)
|
return fn(k.(K), v.(V))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[K, V]) ForEach(f func(K, V) bool) {
|
// Clear removes all key-value pairs from the map
|
||||||
m.mu.RLock()
|
func (g *Map[K, V]) Clear() {
|
||||||
defer m.mu.RUnlock()
|
g.ForEach(func(k K, v V) bool {
|
||||||
for k, v := range m.data {
|
g.Del(k)
|
||||||
if !f(k, v) {
|
return true
|
||||||
break
|
})
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[K, V]) Clear() {
|
// Size returns the number of key-value pairs in the map
|
||||||
m.mu.Lock()
|
func (g *Map[K, V]) Size() int {
|
||||||
defer m.mu.Unlock()
|
count := 0
|
||||||
m.data = make(map[K]V)
|
g.ForEach(func(_ K, _ V) bool {
|
||||||
|
count++
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return count
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[K, V]) Size() int {
|
// Keys returns a slice of all keys in the map
|
||||||
m.mu.RLock()
|
func (g *Map[K, V]) Keys() []K {
|
||||||
defer m.mu.RUnlock()
|
keys := []K{}
|
||||||
return len(m.data)
|
g.ForEach(func(k K, _ V) bool {
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Map[K, V]) Keys() []K {
|
|
||||||
m.mu.RLock()
|
|
||||||
defer m.mu.RUnlock()
|
|
||||||
keys := make([]K, 0, len(m.data))
|
|
||||||
for k := range m.data {
|
|
||||||
keys = append(keys, k)
|
keys = append(keys, k)
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
return keys
|
return keys
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[K, V]) Values() []V {
|
// Values returns a slice of all values in the map
|
||||||
m.mu.RLock()
|
func (g *Map[K, V]) Values() []V {
|
||||||
defer m.mu.RUnlock()
|
values := []V{}
|
||||||
values := make([]V, 0, len(m.data))
|
g.ForEach(func(_ K, v V) bool {
|
||||||
for _, v := range m.data {
|
|
||||||
values = append(values, v)
|
values = append(values, v)
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
return values
|
return values
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[K, V]) AsMap() map[K]V {
|
// AsMap returns a regular map containing all key-value pairs
|
||||||
m.mu.RLock()
|
func (g *Map[K, V]) AsMap() map[K]V {
|
||||||
defer m.mu.RUnlock()
|
result := make(map[K]V)
|
||||||
copiedMap := make(map[K]V, len(m.data))
|
g.ForEach(func(k K, v V) bool {
|
||||||
for k, v := range m.data {
|
result[k] = v
|
||||||
copiedMap[k] = v
|
return true
|
||||||
}
|
})
|
||||||
return copiedMap
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Map[K, V]) Clone() storage.IMap[K, V] {
|
// Clone creates a shallow copy of the map
|
||||||
m.mu.RLock()
|
func (g *Map[K, V]) Clone() storage.IMap[K, V] {
|
||||||
defer m.mu.RUnlock()
|
clone := New[K, V]()
|
||||||
clonedMap := New[K, V]()
|
g.ForEach(func(k K, v V) bool {
|
||||||
for k, v := range m.data {
|
clone.Set(k, v)
|
||||||
clonedMap.Set(k, v)
|
return true
|
||||||
}
|
})
|
||||||
return clonedMap
|
return clone
|
||||||
}
|
}
|
||||||
|
15
utils/conn.go
Normal file
15
utils/conn.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func localAddr(c net.Conn) string { return c.LocalAddr().String() }
|
||||||
|
func remoteAddr(c net.Conn) string { return c.RemoteAddr().String() }
|
||||||
|
|
||||||
|
func ConnectionsEqual(c1, c2 net.Conn) bool {
|
||||||
|
if c1 == nil || c2 == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return localAddr(c1) == localAddr(c2) && remoteAddr(c1) == remoteAddr(c2)
|
||||||
|
}
|
Reference in New Issue
Block a user