feat: Add connection

This commit is contained in:
sujit
2024-10-20 23:24:58 +05:45
parent a06396da56
commit 35a79be4ad
9 changed files with 201 additions and 149 deletions

View File

@@ -75,18 +75,15 @@ func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
}) })
} else { } else {
b.consumers.ForEach(func(consumerID string, con *consumer) bool { b.consumers.ForEach(func(consumerID string, con *consumer) bool {
if con.conn.RemoteAddr().String() == conn.RemoteAddr().String() && if utils.ConnectionsEqual(conn, con.conn) {
con.conn.LocalAddr().String() == conn.LocalAddr().String() { con.conn.Close()
if c, exists := b.consumers.Get(consumerID); exists { b.consumers.Del(consumerID)
c.conn.Close()
b.consumers.Del(consumerID)
}
b.queues.ForEach(func(_ string, queue *Queue) bool { b.queues.ForEach(func(_ string, queue *Queue) bool {
queue.consumers.Del(consumerID)
if _, ok := queue.consumers.Get(consumerID); ok { if _, ok := queue.consumers.Get(consumerID); ok {
if b.opts.consumerOnClose != nil { if b.opts.consumerOnClose != nil {
b.opts.consumerOnClose(ctx, queue.name, consumerID) b.opts.consumerOnClose(ctx, queue.name, consumerID)
} }
queue.consumers.Del(consumerID)
} }
return true return true
}) })
@@ -189,7 +186,7 @@ func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message)
if !ok { if !ok {
return return
} }
err := b.send(con.conn, msg) err := b.send(ctx, con.conn, msg)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -202,7 +199,7 @@ func (b *Broker) Publish(ctx context.Context, task *Task, queue string) error {
return err return err
} }
msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers.AsMap()) msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers.AsMap())
b.broadcastToConsumers(ctx, msg) b.broadcastToConsumers(msg)
return nil return nil
} }
@@ -212,10 +209,10 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID) log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
ack := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers) ack := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
if err := b.send(conn, ack); err != nil { if err := b.send(ctx, conn, ack); err != nil {
log.Printf("Error sending PUBLISH_ACK: %v\n", err) log.Printf("Error sending PUBLISH_ACK: %v\n", err)
} }
b.broadcastToConsumers(ctx, msg) b.broadcastToConsumers(msg)
go func() { go func() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@@ -227,7 +224,7 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) { func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
consumerID := b.AddConsumer(ctx, msg.Queue, conn) consumerID := b.AddConsumer(ctx, msg.Queue, conn)
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers) ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
if err := b.send(conn, ack); err != nil { if err := b.send(ctx, conn, ack); err != nil {
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err) log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
} }
if b.opts.consumerOnSubscribe != nil { if b.opts.consumerOnSubscribe != nil {
@@ -284,23 +281,23 @@ func (b *Broker) Start(ctx context.Context) error {
} }
} }
func (b *Broker) send(conn net.Conn, msg *codec.Message) error { func (b *Broker) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
return codec.SendMessage(conn, msg) return codec.SendMessage(ctx, conn, msg)
} }
func (b *Broker) receive(c net.Conn) (*codec.Message, error) { func (b *Broker) receive(ctx context.Context, c net.Conn) (*codec.Message, error) {
return codec.ReadMessage(c) return codec.ReadMessage(ctx, c)
} }
func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) { func (b *Broker) broadcastToConsumers(msg *codec.Message) {
if queue, ok := b.queues.Get(msg.Queue); ok { if queue, ok := b.queues.Get(msg.Queue); ok {
task := &QueuedTask{Message: msg, RetryCount: 0} task := &QueuedTask{Message: msg, RetryCount: 0}
queue.tasks <- task queue.tasks <- task
} }
} }
func (b *Broker) waitForConsumerAck(conn net.Conn) error { func (b *Broker) waitForConsumerAck(ctx context.Context, conn net.Conn) error {
msg, err := b.receive(conn) msg, err := b.receive(ctx, conn)
if err != nil { if err != nil {
return err return err
} }
@@ -360,12 +357,12 @@ func (b *Broker) RemoveConsumer(consumerID string, queues ...string) {
}) })
} }
func (b *Broker) handleConsumer(cmd consts.CMD, state consts.ConsumerState, consumerID string, queues ...string) { func (b *Broker) handleConsumer(ctx context.Context, cmd consts.CMD, state consts.ConsumerState, consumerID string, queues ...string) {
fn := func(queue *Queue) { fn := func(queue *Queue) {
con, ok := queue.consumers.Get(consumerID) con, ok := queue.consumers.Get(consumerID)
if ok { if ok {
ack := codec.NewMessage(cmd, utils.ToByte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID}) ack := codec.NewMessage(cmd, utils.ToByte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID})
err := b.send(con.conn, ack) err := b.send(ctx, con.conn, ack)
if err == nil { if err == nil {
con.state = state con.state = state
} }
@@ -385,20 +382,20 @@ func (b *Broker) handleConsumer(cmd consts.CMD, state consts.ConsumerState, cons
}) })
} }
func (b *Broker) PauseConsumer(consumerID string, queues ...string) { func (b *Broker) PauseConsumer(ctx context.Context, consumerID string, queues ...string) {
b.handleConsumer(consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, queues...) b.handleConsumer(ctx, consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, queues...)
} }
func (b *Broker) ResumeConsumer(consumerID string, queues ...string) { func (b *Broker) ResumeConsumer(ctx context.Context, consumerID string, queues ...string) {
b.handleConsumer(consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, queues...) b.handleConsumer(ctx, consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, queues...)
} }
func (b *Broker) StopConsumer(consumerID string, queues ...string) { func (b *Broker) StopConsumer(ctx context.Context, consumerID string, queues ...string) {
b.handleConsumer(consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, queues...) b.handleConsumer(ctx, consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, queues...)
} }
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error { func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
msg, err := b.receive(c) msg, err := b.receive(ctx, c)
if err == nil { if err == nil {
ctx = SetHeaders(ctx, msg.Headers) ctx = SetHeaders(ctx, msg.Headers)
b.OnMessage(ctx, msg, c) b.OnMessage(ctx, msg, c)
@@ -412,12 +409,12 @@ func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
return err return err
} }
func (b *Broker) dispatchWorker(queue *Queue) { func (b *Broker) dispatchWorker(ctx context.Context, queue *Queue) {
delay := b.opts.initialDelay delay := b.opts.initialDelay
for task := range queue.tasks { for task := range queue.tasks {
success := false success := false
for !success && task.RetryCount <= b.opts.maxRetries { for !success && task.RetryCount <= b.opts.maxRetries {
if b.dispatchTaskToConsumer(queue, task) { if b.dispatchTaskToConsumer(ctx, queue, task) {
success = true success = true
} else { } else {
task.RetryCount++ task.RetryCount++
@@ -440,7 +437,7 @@ func (b *Broker) sendToDLQ(queue *Queue, task *QueuedTask) {
} }
} }
func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool { func (b *Broker) dispatchTaskToConsumer(ctx context.Context, queue *Queue, task *QueuedTask) bool {
var consumerFound bool var consumerFound bool
var err error var err error
queue.consumers.ForEach(func(_ string, con *consumer) bool { queue.consumers.ForEach(func(_ string, con *consumer) bool {
@@ -448,7 +445,7 @@ func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool {
err = fmt.Errorf("consumer %s is not active", con.id) err = fmt.Errorf("consumer %s is not active", con.id)
return true return true
} }
if err := b.send(con.conn, task.Message); err == nil { if err := b.send(ctx, con.conn, task.Message); err == nil {
consumerFound = true consumerFound = true
return false return false
} }

View File

@@ -1,6 +1,8 @@
package codec package codec
import ( import (
"bufio"
"context"
"encoding/binary" "encoding/binary"
"net" "net"
"sync" "sync"
@@ -12,11 +14,13 @@ type Message struct {
Headers map[string]string `msgpack:"h"` Headers map[string]string `msgpack:"h"`
Queue string `msgpack:"q"` Queue string `msgpack:"q"`
Payload []byte `msgpack:"p"` Payload []byte `msgpack:"p"`
m sync.RWMutex Command consts.CMD `msgpack:"c"`
Command consts.CMD `msgpack:"c"`
} }
func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string]string) *Message { func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string]string) *Message {
if headers == nil {
headers = make(map[string]string)
}
return &Message{ return &Message{
Headers: headers, Headers: headers,
Queue: queue, Queue: queue,
@@ -26,8 +30,6 @@ func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string
} }
func (m *Message) Serialize() ([]byte, error) { func (m *Message) Serialize() ([]byte, error) {
m.m.RLock()
defer m.m.RUnlock()
data, err := Marshal(m) data, err := Marshal(m)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -40,37 +42,66 @@ func Deserialize(data []byte) (*Message, error) {
if err := Unmarshal(data, &msg); err != nil { if err := Unmarshal(data, &msg); err != nil {
return nil, err return nil, err
} }
return &msg, nil return &msg, nil
} }
func SendMessage(conn net.Conn, msg *Message) error { var byteBufferPool = sync.Pool{
New: func() interface{} {
return make([]byte, 0, 4096)
},
}
func SendMessage(ctx context.Context, conn net.Conn, msg *Message) error {
data, err := msg.Serialize() data, err := msg.Serialize()
if err != nil { if err != nil {
return err return err
} }
length := make([]byte, 4) totalLength := 4 + len(data)
binary.BigEndian.PutUint32(length, uint32(len(data))) buffer := byteBufferPool.Get().([]byte)
if cap(buffer) < totalLength {
buffer = make([]byte, totalLength)
} else {
buffer = buffer[:totalLength]
}
defer byteBufferPool.Put(buffer)
if _, err := conn.Write(length); err != nil { binary.BigEndian.PutUint32(buffer[:4], uint32(len(data)))
return err copy(buffer[4:], data)
writer := bufio.NewWriter(conn)
select {
case <-ctx.Done():
return ctx.Err()
default:
if _, err := writer.Write(buffer); err != nil {
return err
}
} }
if _, err := conn.Write(data); err != nil {
return err return writer.Flush()
}
return nil
} }
func ReadMessage(conn net.Conn) (*Message, error) { func ReadMessage(ctx context.Context, conn net.Conn) (*Message, error) {
lengthBytes := make([]byte, 4) lengthBytes := make([]byte, 4)
if _, err := conn.Read(lengthBytes); err != nil { if _, err := conn.Read(lengthBytes); err != nil {
return nil, err return nil, err
} }
length := binary.BigEndian.Uint32(lengthBytes) length := binary.BigEndian.Uint32(lengthBytes)
data := make([]byte, length) data := byteBufferPool.Get().([]byte)[:length]
if _, err := conn.Read(data); err != nil { defer byteBufferPool.Put(data)
return nil, err totalRead := 0
for totalRead < int(length) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
n, err := conn.Read(data[totalRead:])
if err != nil {
return nil, err
}
totalRead += n
}
} }
return Deserialize(data) return Deserialize(data[:length])
} }

View File

@@ -7,7 +7,6 @@ import (
"log" "log"
"net" "net"
"strings" "strings"
"sync"
"time" "time"
"github.com/oarkflow/mq/codec" "github.com/oarkflow/mq/codec"
@@ -47,12 +46,12 @@ func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Cons
} }
} }
func (c *Consumer) send(conn net.Conn, msg *codec.Message) error { func (c *Consumer) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
return codec.SendMessage(conn, msg) return codec.SendMessage(ctx, conn, msg)
} }
func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) { func (c *Consumer) receive(ctx context.Context, conn net.Conn) (*codec.Message, error) {
return codec.ReadMessage(conn) return codec.ReadMessage(ctx, conn)
} }
func (c *Consumer) Close() error { func (c *Consumer) Close() error {
@@ -75,10 +74,10 @@ func (c *Consumer) SetKey(key string) {
func (c *Consumer) subscribe(ctx context.Context, queue string) error { func (c *Consumer) subscribe(ctx context.Context, queue string) error {
headers := HeadersWithConsumerID(ctx, c.id) headers := HeadersWithConsumerID(ctx, c.id)
msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers) msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
if err := c.send(c.conn, msg); err != nil { if err := c.send(ctx, c.conn, msg); err != nil {
return fmt.Errorf("error while trying to subscribe: %v", err) return fmt.Errorf("error while trying to subscribe: %v", err)
} }
return c.waitForAck(c.conn) return c.waitForAck(ctx, c.conn)
} }
func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error { func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error {
@@ -122,7 +121,7 @@ func (c *Consumer) sendMessageAck(ctx context.Context, msg *codec.Message, conn
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue) headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue)
taskID, _ := jsonparser.GetString(msg.Payload, "id") taskID, _ := jsonparser.GetString(msg.Payload, "id")
reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers) reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
if err := c.send(conn, reply); err != nil { if err := c.send(ctx, conn, reply); err != nil {
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err) fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
} }
} }
@@ -172,7 +171,7 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
if result.Payload != nil || result.Error != nil { if result.Payload != nil || result.Error != nil {
bt, _ := json.Marshal(result) bt, _ := json.Marshal(result)
reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers) reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers)
if err := c.send(c.conn, reply); err != nil { if err := c.send(ctx, c.conn, reply); err != nil {
return fmt.Errorf("failed to send MESSAGE_RESPONSE: %v", err) return fmt.Errorf("failed to send MESSAGE_RESPONSE: %v", err)
} }
} }
@@ -182,7 +181,7 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) { func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) {
headers := HeadersWithConsumerID(ctx, c.id) headers := HeadersWithConsumerID(ctx, c.id)
reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers) reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
if sendErr := c.send(c.conn, reply); sendErr != nil { if sendErr := c.send(ctx, c.conn, reply); sendErr != nil {
log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr) log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr)
} }
} }
@@ -209,7 +208,7 @@ func (c *Consumer) attemptConnect() error {
} }
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error { func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
msg, err := c.receive(conn) msg, err := c.receive(ctx, conn)
if err == nil { if err == nil {
ctx = SetHeaders(ctx, msg.Headers) ctx = SetHeaders(ctx, msg.Headers)
return c.OnMessage(ctx, msg, conn) return c.OnMessage(ctx, msg, conn)
@@ -235,24 +234,32 @@ func (c *Consumer) Consume(ctx context.Context) error {
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err) return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
} }
c.pool.Start(c.opts.numOfWorkers) c.pool.Start(c.opts.numOfWorkers)
var wg sync.WaitGroup stopChan := make(chan struct{})
wg.Add(1)
go func() { go func() {
defer wg.Done() defer close(stopChan) // Signal completion when done
for { for {
if err := c.readMessage(ctx, c.conn); err != nil { select {
log.Println("Error reading message:", err) case <-ctx.Done():
break log.Println("Context canceled, stopping message reading.")
return
default:
if err := c.readMessage(ctx, c.conn); err != nil {
log.Println("Error reading message:", err)
return // Exit the goroutine on error
}
} }
} }
}() }()
select {
wg.Wait() case <-stopChan:
case <-ctx.Done():
log.Println("Context canceled, performing cleanup.")
}
return nil return nil
} }
func (c *Consumer) waitForAck(conn net.Conn) error { func (c *Consumer) waitForAck(ctx context.Context, conn net.Conn) error {
msg, err := c.receive(conn) msg, err := c.receive(ctx, conn)
if err != nil { if err != nil {
return err return err
} }
@@ -286,7 +293,7 @@ func (c *Consumer) operate(ctx context.Context, cmd consts.CMD, poolOperation fu
func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error { func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
headers := HeadersWithConsumerID(ctx, c.id) headers := HeadersWithConsumerID(ctx, c.id)
msg := codec.NewMessage(cmd, nil, c.queue, headers) msg := codec.NewMessage(cmd, nil, c.queue, headers)
return c.send(c.conn, msg) return c.send(ctx, c.conn, msg)
} }
func (c *Consumer) Conn() net.Conn { func (c *Consumer) Conn() net.Conn {

View File

@@ -9,6 +9,7 @@ import (
) )
func main() { func main() {
consumer1 := mq.NewConsumer("F", "queue1", tasks.Node6{}.ProcessTask, mq.WithWorkerPool(100, 4, 50000)) n := &tasks.Node6{}
consumer1 := mq.NewConsumer("F", "queue1", n.ProcessTask, mq.WithWorkerPool(100, 4, 50000))
consumer1.Consume(context.Background()) consumer1.Consume(context.Background())
} }

View File

@@ -3,7 +3,6 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"time"
"github.com/oarkflow/mq" "github.com/oarkflow/mq"
) )
@@ -25,7 +24,6 @@ func main() {
Payload: payload, Payload: payload,
} }
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
time.Sleep(500 * time.Millisecond)
err := publisher.Publish(context.Background(), task, "queue1") err := publisher.Publish(context.Background(), task, "queue1")
if err != nil { if err != nil {
panic(err) panic(err)

View File

@@ -37,15 +37,15 @@ func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.
return err return err
} }
msg := codec.NewMessage(command, payload, queue, headers) msg := codec.NewMessage(command, payload, queue, headers)
if err := codec.SendMessage(conn, msg); err != nil { if err := codec.SendMessage(ctx, conn, msg); err != nil {
return err return err
} }
return p.waitForAck(conn) return p.waitForAck(ctx, conn)
} }
func (p *Publisher) waitForAck(conn net.Conn) error { func (p *Publisher) waitForAck(ctx context.Context, conn net.Conn) error {
msg, err := codec.ReadMessage(conn) msg, err := codec.ReadMessage(ctx, conn)
if err != nil { if err != nil {
return err return err
} }
@@ -57,8 +57,8 @@ func (p *Publisher) waitForAck(conn net.Conn) error {
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command) return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
} }
func (p *Publisher) waitForResponse(conn net.Conn) Result { func (p *Publisher) waitForResponse(ctx context.Context, conn net.Conn) Result {
msg, err := codec.ReadMessage(conn) msg, err := codec.ReadMessage(ctx, conn)
if err != nil { if err != nil {
return Result{Error: err} return Result{Error: err}
} }
@@ -103,7 +103,7 @@ func (p *Publisher) Request(ctx context.Context, queue string, task Task) Result
resultCh := make(chan Result) resultCh := make(chan Result)
go func() { go func() {
defer close(resultCh) defer close(resultCh)
resultCh <- p.waitForResponse(conn) resultCh <- p.waitForResponse(ctx, conn)
}() }()
finalResult := <-resultCh finalResult := <-resultCh
return finalResult return finalResult

View File

@@ -36,9 +36,9 @@ func (b *Broker) NewQueue(name string) *Queue {
consumers: memory.New[string, *consumer](), consumers: memory.New[string, *consumer](),
} }
b.deadLetter.Set(name, dlq) b.deadLetter.Set(name, dlq)
ctx := context.Background()
go b.dispatchWorker(q) go b.dispatchWorker(ctx, q)
go b.dispatchWorker(dlq) go b.dispatchWorker(ctx, dlq)
return q return q
} }

View File

@@ -8,94 +8,97 @@ import (
var _ storage.IMap[string, any] = (*Map[string, any])(nil) var _ storage.IMap[string, any] = (*Map[string, any])(nil)
// Map is a thread-safe map using sync.Map with generics
type Map[K comparable, V any] struct { type Map[K comparable, V any] struct {
data map[K]V m sync.Map
mu sync.RWMutex
} }
// New creates a new Map
func New[K comparable, V any]() *Map[K, V] { func New[K comparable, V any]() *Map[K, V] {
return &Map[K, V]{ return &Map[K, V]{}
data: make(map[K]V), }
// Get retrieves the value for a given key
func (g *Map[K, V]) Get(key K) (V, bool) {
val, ok := g.m.Load(key)
if !ok {
var zeroValue V
return zeroValue, false
} }
return val.(V), true
} }
func (m *Map[K, V]) Get(key K) (V, bool) { // Set adds a key-value pair to the map
m.mu.RLock() func (g *Map[K, V]) Set(key K, value V) {
defer m.mu.RUnlock() g.m.Store(key, value)
val, exists := m.data[key]
return val, exists
} }
func (m *Map[K, V]) Set(key K, value V) { // Del removes a key-value pair from the map
m.mu.Lock() func (g *Map[K, V]) Del(key K) {
defer m.mu.Unlock() g.m.Delete(key)
m.data[key] = value
} }
func (m *Map[K, V]) Del(key K) { // ForEach iterates over the map
m.mu.Lock() func (g *Map[K, V]) ForEach(fn func(K, V) bool) {
defer m.mu.Unlock() g.m.Range(func(k, v any) bool {
delete(m.data, key) return fn(k.(K), v.(V))
})
} }
func (m *Map[K, V]) ForEach(f func(K, V) bool) { // Clear removes all key-value pairs from the map
m.mu.RLock() func (g *Map[K, V]) Clear() {
defer m.mu.RUnlock() g.ForEach(func(k K, v V) bool {
for k, v := range m.data { g.Del(k)
if !f(k, v) { return true
break })
}
}
} }
func (m *Map[K, V]) Clear() { // Size returns the number of key-value pairs in the map
m.mu.Lock() func (g *Map[K, V]) Size() int {
defer m.mu.Unlock() count := 0
m.data = make(map[K]V) g.ForEach(func(_ K, _ V) bool {
count++
return true
})
return count
} }
func (m *Map[K, V]) Size() int { // Keys returns a slice of all keys in the map
m.mu.RLock() func (g *Map[K, V]) Keys() []K {
defer m.mu.RUnlock() keys := []K{}
return len(m.data) g.ForEach(func(k K, _ V) bool {
}
func (m *Map[K, V]) Keys() []K {
m.mu.RLock()
defer m.mu.RUnlock()
keys := make([]K, 0, len(m.data))
for k := range m.data {
keys = append(keys, k) keys = append(keys, k)
} return true
})
return keys return keys
} }
func (m *Map[K, V]) Values() []V { // Values returns a slice of all values in the map
m.mu.RLock() func (g *Map[K, V]) Values() []V {
defer m.mu.RUnlock() values := []V{}
values := make([]V, 0, len(m.data)) g.ForEach(func(_ K, v V) bool {
for _, v := range m.data {
values = append(values, v) values = append(values, v)
} return true
})
return values return values
} }
func (m *Map[K, V]) AsMap() map[K]V { // AsMap returns a regular map containing all key-value pairs
m.mu.RLock() func (g *Map[K, V]) AsMap() map[K]V {
defer m.mu.RUnlock() result := make(map[K]V)
copiedMap := make(map[K]V, len(m.data)) g.ForEach(func(k K, v V) bool {
for k, v := range m.data { result[k] = v
copiedMap[k] = v return true
} })
return copiedMap return result
} }
func (m *Map[K, V]) Clone() storage.IMap[K, V] { // Clone creates a shallow copy of the map
m.mu.RLock() func (g *Map[K, V]) Clone() storage.IMap[K, V] {
defer m.mu.RUnlock() clone := New[K, V]()
clonedMap := New[K, V]() g.ForEach(func(k K, v V) bool {
for k, v := range m.data { clone.Set(k, v)
clonedMap.Set(k, v) return true
} })
return clonedMap return clone
} }

15
utils/conn.go Normal file
View File

@@ -0,0 +1,15 @@
package utils
import (
"net"
)
func localAddr(c net.Conn) string { return c.LocalAddr().String() }
func remoteAddr(c net.Conn) string { return c.RemoteAddr().String() }
func ConnectionsEqual(c1, c2 net.Conn) bool {
if c1 == nil || c2 == nil {
return false
}
return localAddr(c1) == localAddr(c2) && remoteAddr(c1) == remoteAddr(c2)
}