feat: Add connection

This commit is contained in:
sujit
2024-10-20 23:24:58 +05:45
parent a06396da56
commit 35a79be4ad
9 changed files with 201 additions and 149 deletions

View File

@@ -75,18 +75,15 @@ func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
})
} else {
b.consumers.ForEach(func(consumerID string, con *consumer) bool {
if con.conn.RemoteAddr().String() == conn.RemoteAddr().String() &&
con.conn.LocalAddr().String() == conn.LocalAddr().String() {
if c, exists := b.consumers.Get(consumerID); exists {
c.conn.Close()
b.consumers.Del(consumerID)
}
if utils.ConnectionsEqual(conn, con.conn) {
con.conn.Close()
b.consumers.Del(consumerID)
b.queues.ForEach(func(_ string, queue *Queue) bool {
queue.consumers.Del(consumerID)
if _, ok := queue.consumers.Get(consumerID); ok {
if b.opts.consumerOnClose != nil {
b.opts.consumerOnClose(ctx, queue.name, consumerID)
}
queue.consumers.Del(consumerID)
}
return true
})
@@ -189,7 +186,7 @@ func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message)
if !ok {
return
}
err := b.send(con.conn, msg)
err := b.send(ctx, con.conn, msg)
if err != nil {
panic(err)
}
@@ -202,7 +199,7 @@ func (b *Broker) Publish(ctx context.Context, task *Task, queue string) error {
return err
}
msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers.AsMap())
b.broadcastToConsumers(ctx, msg)
b.broadcastToConsumers(msg)
return nil
}
@@ -212,10 +209,10 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
ack := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
if err := b.send(conn, ack); err != nil {
if err := b.send(ctx, conn, ack); err != nil {
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
}
b.broadcastToConsumers(ctx, msg)
b.broadcastToConsumers(msg)
go func() {
select {
case <-ctx.Done():
@@ -227,7 +224,7 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
consumerID := b.AddConsumer(ctx, msg.Queue, conn)
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
if err := b.send(conn, ack); err != nil {
if err := b.send(ctx, conn, ack); err != nil {
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
}
if b.opts.consumerOnSubscribe != nil {
@@ -284,23 +281,23 @@ func (b *Broker) Start(ctx context.Context) error {
}
}
func (b *Broker) send(conn net.Conn, msg *codec.Message) error {
return codec.SendMessage(conn, msg)
func (b *Broker) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
return codec.SendMessage(ctx, conn, msg)
}
func (b *Broker) receive(c net.Conn) (*codec.Message, error) {
return codec.ReadMessage(c)
func (b *Broker) receive(ctx context.Context, c net.Conn) (*codec.Message, error) {
return codec.ReadMessage(ctx, c)
}
func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) {
func (b *Broker) broadcastToConsumers(msg *codec.Message) {
if queue, ok := b.queues.Get(msg.Queue); ok {
task := &QueuedTask{Message: msg, RetryCount: 0}
queue.tasks <- task
}
}
func (b *Broker) waitForConsumerAck(conn net.Conn) error {
msg, err := b.receive(conn)
func (b *Broker) waitForConsumerAck(ctx context.Context, conn net.Conn) error {
msg, err := b.receive(ctx, conn)
if err != nil {
return err
}
@@ -360,12 +357,12 @@ func (b *Broker) RemoveConsumer(consumerID string, queues ...string) {
})
}
func (b *Broker) handleConsumer(cmd consts.CMD, state consts.ConsumerState, consumerID string, queues ...string) {
func (b *Broker) handleConsumer(ctx context.Context, cmd consts.CMD, state consts.ConsumerState, consumerID string, queues ...string) {
fn := func(queue *Queue) {
con, ok := queue.consumers.Get(consumerID)
if ok {
ack := codec.NewMessage(cmd, utils.ToByte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID})
err := b.send(con.conn, ack)
err := b.send(ctx, con.conn, ack)
if err == nil {
con.state = state
}
@@ -385,20 +382,20 @@ func (b *Broker) handleConsumer(cmd consts.CMD, state consts.ConsumerState, cons
})
}
func (b *Broker) PauseConsumer(consumerID string, queues ...string) {
b.handleConsumer(consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, queues...)
func (b *Broker) PauseConsumer(ctx context.Context, consumerID string, queues ...string) {
b.handleConsumer(ctx, consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, queues...)
}
func (b *Broker) ResumeConsumer(consumerID string, queues ...string) {
b.handleConsumer(consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, queues...)
func (b *Broker) ResumeConsumer(ctx context.Context, consumerID string, queues ...string) {
b.handleConsumer(ctx, consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, queues...)
}
func (b *Broker) StopConsumer(consumerID string, queues ...string) {
b.handleConsumer(consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, queues...)
func (b *Broker) StopConsumer(ctx context.Context, consumerID string, queues ...string) {
b.handleConsumer(ctx, consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, queues...)
}
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
msg, err := b.receive(c)
msg, err := b.receive(ctx, c)
if err == nil {
ctx = SetHeaders(ctx, msg.Headers)
b.OnMessage(ctx, msg, c)
@@ -412,12 +409,12 @@ func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
return err
}
func (b *Broker) dispatchWorker(queue *Queue) {
func (b *Broker) dispatchWorker(ctx context.Context, queue *Queue) {
delay := b.opts.initialDelay
for task := range queue.tasks {
success := false
for !success && task.RetryCount <= b.opts.maxRetries {
if b.dispatchTaskToConsumer(queue, task) {
if b.dispatchTaskToConsumer(ctx, queue, task) {
success = true
} else {
task.RetryCount++
@@ -440,7 +437,7 @@ func (b *Broker) sendToDLQ(queue *Queue, task *QueuedTask) {
}
}
func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool {
func (b *Broker) dispatchTaskToConsumer(ctx context.Context, queue *Queue, task *QueuedTask) bool {
var consumerFound bool
var err error
queue.consumers.ForEach(func(_ string, con *consumer) bool {
@@ -448,7 +445,7 @@ func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool {
err = fmt.Errorf("consumer %s is not active", con.id)
return true
}
if err := b.send(con.conn, task.Message); err == nil {
if err := b.send(ctx, con.conn, task.Message); err == nil {
consumerFound = true
return false
}

View File

@@ -1,6 +1,8 @@
package codec
import (
"bufio"
"context"
"encoding/binary"
"net"
"sync"
@@ -12,11 +14,13 @@ type Message struct {
Headers map[string]string `msgpack:"h"`
Queue string `msgpack:"q"`
Payload []byte `msgpack:"p"`
m sync.RWMutex
Command consts.CMD `msgpack:"c"`
Command consts.CMD `msgpack:"c"`
}
func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string]string) *Message {
if headers == nil {
headers = make(map[string]string)
}
return &Message{
Headers: headers,
Queue: queue,
@@ -26,8 +30,6 @@ func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string
}
func (m *Message) Serialize() ([]byte, error) {
m.m.RLock()
defer m.m.RUnlock()
data, err := Marshal(m)
if err != nil {
return nil, err
@@ -40,37 +42,66 @@ func Deserialize(data []byte) (*Message, error) {
if err := Unmarshal(data, &msg); err != nil {
return nil, err
}
return &msg, nil
}
func SendMessage(conn net.Conn, msg *Message) error {
var byteBufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, 0, 4096)
},
}
func SendMessage(ctx context.Context, conn net.Conn, msg *Message) error {
data, err := msg.Serialize()
if err != nil {
return err
}
length := make([]byte, 4)
binary.BigEndian.PutUint32(length, uint32(len(data)))
totalLength := 4 + len(data)
buffer := byteBufferPool.Get().([]byte)
if cap(buffer) < totalLength {
buffer = make([]byte, totalLength)
} else {
buffer = buffer[:totalLength]
}
defer byteBufferPool.Put(buffer)
if _, err := conn.Write(length); err != nil {
return err
binary.BigEndian.PutUint32(buffer[:4], uint32(len(data)))
copy(buffer[4:], data)
writer := bufio.NewWriter(conn)
select {
case <-ctx.Done():
return ctx.Err()
default:
if _, err := writer.Write(buffer); err != nil {
return err
}
}
if _, err := conn.Write(data); err != nil {
return err
}
return nil
return writer.Flush()
}
func ReadMessage(conn net.Conn) (*Message, error) {
func ReadMessage(ctx context.Context, conn net.Conn) (*Message, error) {
lengthBytes := make([]byte, 4)
if _, err := conn.Read(lengthBytes); err != nil {
return nil, err
}
length := binary.BigEndian.Uint32(lengthBytes)
data := make([]byte, length)
if _, err := conn.Read(data); err != nil {
return nil, err
data := byteBufferPool.Get().([]byte)[:length]
defer byteBufferPool.Put(data)
totalRead := 0
for totalRead < int(length) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
n, err := conn.Read(data[totalRead:])
if err != nil {
return nil, err
}
totalRead += n
}
}
return Deserialize(data)
return Deserialize(data[:length])
}

View File

@@ -7,7 +7,6 @@ import (
"log"
"net"
"strings"
"sync"
"time"
"github.com/oarkflow/mq/codec"
@@ -47,12 +46,12 @@ func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Cons
}
}
func (c *Consumer) send(conn net.Conn, msg *codec.Message) error {
return codec.SendMessage(conn, msg)
func (c *Consumer) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
return codec.SendMessage(ctx, conn, msg)
}
func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) {
return codec.ReadMessage(conn)
func (c *Consumer) receive(ctx context.Context, conn net.Conn) (*codec.Message, error) {
return codec.ReadMessage(ctx, conn)
}
func (c *Consumer) Close() error {
@@ -75,10 +74,10 @@ func (c *Consumer) SetKey(key string) {
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
headers := HeadersWithConsumerID(ctx, c.id)
msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
if err := c.send(c.conn, msg); err != nil {
if err := c.send(ctx, c.conn, msg); err != nil {
return fmt.Errorf("error while trying to subscribe: %v", err)
}
return c.waitForAck(c.conn)
return c.waitForAck(ctx, c.conn)
}
func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error {
@@ -122,7 +121,7 @@ func (c *Consumer) sendMessageAck(ctx context.Context, msg *codec.Message, conn
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue)
taskID, _ := jsonparser.GetString(msg.Payload, "id")
reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
if err := c.send(conn, reply); err != nil {
if err := c.send(ctx, conn, reply); err != nil {
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
}
}
@@ -172,7 +171,7 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
if result.Payload != nil || result.Error != nil {
bt, _ := json.Marshal(result)
reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers)
if err := c.send(c.conn, reply); err != nil {
if err := c.send(ctx, c.conn, reply); err != nil {
return fmt.Errorf("failed to send MESSAGE_RESPONSE: %v", err)
}
}
@@ -182,7 +181,7 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) {
headers := HeadersWithConsumerID(ctx, c.id)
reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
if sendErr := c.send(c.conn, reply); sendErr != nil {
if sendErr := c.send(ctx, c.conn, reply); sendErr != nil {
log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr)
}
}
@@ -209,7 +208,7 @@ func (c *Consumer) attemptConnect() error {
}
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
msg, err := c.receive(conn)
msg, err := c.receive(ctx, conn)
if err == nil {
ctx = SetHeaders(ctx, msg.Headers)
return c.OnMessage(ctx, msg, conn)
@@ -235,24 +234,32 @@ func (c *Consumer) Consume(ctx context.Context) error {
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
}
c.pool.Start(c.opts.numOfWorkers)
var wg sync.WaitGroup
wg.Add(1)
stopChan := make(chan struct{})
go func() {
defer wg.Done()
defer close(stopChan) // Signal completion when done
for {
if err := c.readMessage(ctx, c.conn); err != nil {
log.Println("Error reading message:", err)
break
select {
case <-ctx.Done():
log.Println("Context canceled, stopping message reading.")
return
default:
if err := c.readMessage(ctx, c.conn); err != nil {
log.Println("Error reading message:", err)
return // Exit the goroutine on error
}
}
}
}()
wg.Wait()
select {
case <-stopChan:
case <-ctx.Done():
log.Println("Context canceled, performing cleanup.")
}
return nil
}
func (c *Consumer) waitForAck(conn net.Conn) error {
msg, err := c.receive(conn)
func (c *Consumer) waitForAck(ctx context.Context, conn net.Conn) error {
msg, err := c.receive(ctx, conn)
if err != nil {
return err
}
@@ -286,7 +293,7 @@ func (c *Consumer) operate(ctx context.Context, cmd consts.CMD, poolOperation fu
func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
headers := HeadersWithConsumerID(ctx, c.id)
msg := codec.NewMessage(cmd, nil, c.queue, headers)
return c.send(c.conn, msg)
return c.send(ctx, c.conn, msg)
}
func (c *Consumer) Conn() net.Conn {

View File

@@ -9,6 +9,7 @@ import (
)
func main() {
consumer1 := mq.NewConsumer("F", "queue1", tasks.Node6{}.ProcessTask, mq.WithWorkerPool(100, 4, 50000))
n := &tasks.Node6{}
consumer1 := mq.NewConsumer("F", "queue1", n.ProcessTask, mq.WithWorkerPool(100, 4, 50000))
consumer1.Consume(context.Background())
}

View File

@@ -3,7 +3,6 @@ package main
import (
"context"
"fmt"
"time"
"github.com/oarkflow/mq"
)
@@ -25,7 +24,6 @@ func main() {
Payload: payload,
}
for i := 0; i < 100; i++ {
time.Sleep(500 * time.Millisecond)
err := publisher.Publish(context.Background(), task, "queue1")
if err != nil {
panic(err)

View File

@@ -37,15 +37,15 @@ func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.
return err
}
msg := codec.NewMessage(command, payload, queue, headers)
if err := codec.SendMessage(conn, msg); err != nil {
if err := codec.SendMessage(ctx, conn, msg); err != nil {
return err
}
return p.waitForAck(conn)
return p.waitForAck(ctx, conn)
}
func (p *Publisher) waitForAck(conn net.Conn) error {
msg, err := codec.ReadMessage(conn)
func (p *Publisher) waitForAck(ctx context.Context, conn net.Conn) error {
msg, err := codec.ReadMessage(ctx, conn)
if err != nil {
return err
}
@@ -57,8 +57,8 @@ func (p *Publisher) waitForAck(conn net.Conn) error {
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
}
func (p *Publisher) waitForResponse(conn net.Conn) Result {
msg, err := codec.ReadMessage(conn)
func (p *Publisher) waitForResponse(ctx context.Context, conn net.Conn) Result {
msg, err := codec.ReadMessage(ctx, conn)
if err != nil {
return Result{Error: err}
}
@@ -103,7 +103,7 @@ func (p *Publisher) Request(ctx context.Context, queue string, task Task) Result
resultCh := make(chan Result)
go func() {
defer close(resultCh)
resultCh <- p.waitForResponse(conn)
resultCh <- p.waitForResponse(ctx, conn)
}()
finalResult := <-resultCh
return finalResult

View File

@@ -36,9 +36,9 @@ func (b *Broker) NewQueue(name string) *Queue {
consumers: memory.New[string, *consumer](),
}
b.deadLetter.Set(name, dlq)
go b.dispatchWorker(q)
go b.dispatchWorker(dlq)
ctx := context.Background()
go b.dispatchWorker(ctx, q)
go b.dispatchWorker(ctx, dlq)
return q
}

View File

@@ -8,94 +8,97 @@ import (
var _ storage.IMap[string, any] = (*Map[string, any])(nil)
// Map is a thread-safe map using sync.Map with generics
type Map[K comparable, V any] struct {
data map[K]V
mu sync.RWMutex
m sync.Map
}
// New creates a new Map
func New[K comparable, V any]() *Map[K, V] {
return &Map[K, V]{
data: make(map[K]V),
return &Map[K, V]{}
}
// Get retrieves the value for a given key
func (g *Map[K, V]) Get(key K) (V, bool) {
val, ok := g.m.Load(key)
if !ok {
var zeroValue V
return zeroValue, false
}
return val.(V), true
}
func (m *Map[K, V]) Get(key K) (V, bool) {
m.mu.RLock()
defer m.mu.RUnlock()
val, exists := m.data[key]
return val, exists
// Set adds a key-value pair to the map
func (g *Map[K, V]) Set(key K, value V) {
g.m.Store(key, value)
}
func (m *Map[K, V]) Set(key K, value V) {
m.mu.Lock()
defer m.mu.Unlock()
m.data[key] = value
// Del removes a key-value pair from the map
func (g *Map[K, V]) Del(key K) {
g.m.Delete(key)
}
func (m *Map[K, V]) Del(key K) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.data, key)
// ForEach iterates over the map
func (g *Map[K, V]) ForEach(fn func(K, V) bool) {
g.m.Range(func(k, v any) bool {
return fn(k.(K), v.(V))
})
}
func (m *Map[K, V]) ForEach(f func(K, V) bool) {
m.mu.RLock()
defer m.mu.RUnlock()
for k, v := range m.data {
if !f(k, v) {
break
}
}
// Clear removes all key-value pairs from the map
func (g *Map[K, V]) Clear() {
g.ForEach(func(k K, v V) bool {
g.Del(k)
return true
})
}
func (m *Map[K, V]) Clear() {
m.mu.Lock()
defer m.mu.Unlock()
m.data = make(map[K]V)
// Size returns the number of key-value pairs in the map
func (g *Map[K, V]) Size() int {
count := 0
g.ForEach(func(_ K, _ V) bool {
count++
return true
})
return count
}
func (m *Map[K, V]) Size() int {
m.mu.RLock()
defer m.mu.RUnlock()
return len(m.data)
}
func (m *Map[K, V]) Keys() []K {
m.mu.RLock()
defer m.mu.RUnlock()
keys := make([]K, 0, len(m.data))
for k := range m.data {
// Keys returns a slice of all keys in the map
func (g *Map[K, V]) Keys() []K {
keys := []K{}
g.ForEach(func(k K, _ V) bool {
keys = append(keys, k)
}
return true
})
return keys
}
func (m *Map[K, V]) Values() []V {
m.mu.RLock()
defer m.mu.RUnlock()
values := make([]V, 0, len(m.data))
for _, v := range m.data {
// Values returns a slice of all values in the map
func (g *Map[K, V]) Values() []V {
values := []V{}
g.ForEach(func(_ K, v V) bool {
values = append(values, v)
}
return true
})
return values
}
func (m *Map[K, V]) AsMap() map[K]V {
m.mu.RLock()
defer m.mu.RUnlock()
copiedMap := make(map[K]V, len(m.data))
for k, v := range m.data {
copiedMap[k] = v
}
return copiedMap
// AsMap returns a regular map containing all key-value pairs
func (g *Map[K, V]) AsMap() map[K]V {
result := make(map[K]V)
g.ForEach(func(k K, v V) bool {
result[k] = v
return true
})
return result
}
func (m *Map[K, V]) Clone() storage.IMap[K, V] {
m.mu.RLock()
defer m.mu.RUnlock()
clonedMap := New[K, V]()
for k, v := range m.data {
clonedMap.Set(k, v)
}
return clonedMap
// Clone creates a shallow copy of the map
func (g *Map[K, V]) Clone() storage.IMap[K, V] {
clone := New[K, V]()
g.ForEach(func(k K, v V) bool {
clone.Set(k, v)
return true
})
return clone
}

15
utils/conn.go Normal file
View File

@@ -0,0 +1,15 @@
package utils
import (
"net"
)
func localAddr(c net.Conn) string { return c.LocalAddr().String() }
func remoteAddr(c net.Conn) string { return c.RemoteAddr().String() }
func ConnectionsEqual(c1, c2 net.Conn) bool {
if c1 == nil || c2 == nil {
return false
}
return localAddr(c1) == localAddr(c2) && remoteAddr(c1) == remoteAddr(c2)
}