mirror of
https://github.com/oarkflow/mq.git
synced 2025-09-27 04:15:52 +08:00
feat: Add connection
This commit is contained in:
61
broker.go
61
broker.go
@@ -75,18 +75,15 @@ func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
|
||||
})
|
||||
} else {
|
||||
b.consumers.ForEach(func(consumerID string, con *consumer) bool {
|
||||
if con.conn.RemoteAddr().String() == conn.RemoteAddr().String() &&
|
||||
con.conn.LocalAddr().String() == conn.LocalAddr().String() {
|
||||
if c, exists := b.consumers.Get(consumerID); exists {
|
||||
c.conn.Close()
|
||||
b.consumers.Del(consumerID)
|
||||
}
|
||||
if utils.ConnectionsEqual(conn, con.conn) {
|
||||
con.conn.Close()
|
||||
b.consumers.Del(consumerID)
|
||||
b.queues.ForEach(func(_ string, queue *Queue) bool {
|
||||
queue.consumers.Del(consumerID)
|
||||
if _, ok := queue.consumers.Get(consumerID); ok {
|
||||
if b.opts.consumerOnClose != nil {
|
||||
b.opts.consumerOnClose(ctx, queue.name, consumerID)
|
||||
}
|
||||
queue.consumers.Del(consumerID)
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -189,7 +186,7 @@ func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
err := b.send(con.conn, msg)
|
||||
err := b.send(ctx, con.conn, msg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -202,7 +199,7 @@ func (b *Broker) Publish(ctx context.Context, task *Task, queue string) error {
|
||||
return err
|
||||
}
|
||||
msg := codec.NewMessage(consts.PUBLISH, payload, queue, headers.AsMap())
|
||||
b.broadcastToConsumers(ctx, msg)
|
||||
b.broadcastToConsumers(msg)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -212,10 +209,10 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
|
||||
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
|
||||
|
||||
ack := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
|
||||
if err := b.send(conn, ack); err != nil {
|
||||
if err := b.send(ctx, conn, ack); err != nil {
|
||||
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
|
||||
}
|
||||
b.broadcastToConsumers(ctx, msg)
|
||||
b.broadcastToConsumers(msg)
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -227,7 +224,7 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
|
||||
func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
||||
consumerID := b.AddConsumer(ctx, msg.Queue, conn)
|
||||
ack := codec.NewMessage(consts.SUBSCRIBE_ACK, nil, msg.Queue, msg.Headers)
|
||||
if err := b.send(conn, ack); err != nil {
|
||||
if err := b.send(ctx, conn, ack); err != nil {
|
||||
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
|
||||
}
|
||||
if b.opts.consumerOnSubscribe != nil {
|
||||
@@ -284,23 +281,23 @@ func (b *Broker) Start(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) send(conn net.Conn, msg *codec.Message) error {
|
||||
return codec.SendMessage(conn, msg)
|
||||
func (b *Broker) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
|
||||
return codec.SendMessage(ctx, conn, msg)
|
||||
}
|
||||
|
||||
func (b *Broker) receive(c net.Conn) (*codec.Message, error) {
|
||||
return codec.ReadMessage(c)
|
||||
func (b *Broker) receive(ctx context.Context, c net.Conn) (*codec.Message, error) {
|
||||
return codec.ReadMessage(ctx, c)
|
||||
}
|
||||
|
||||
func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) {
|
||||
func (b *Broker) broadcastToConsumers(msg *codec.Message) {
|
||||
if queue, ok := b.queues.Get(msg.Queue); ok {
|
||||
task := &QueuedTask{Message: msg, RetryCount: 0}
|
||||
queue.tasks <- task
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) waitForConsumerAck(conn net.Conn) error {
|
||||
msg, err := b.receive(conn)
|
||||
func (b *Broker) waitForConsumerAck(ctx context.Context, conn net.Conn) error {
|
||||
msg, err := b.receive(ctx, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -360,12 +357,12 @@ func (b *Broker) RemoveConsumer(consumerID string, queues ...string) {
|
||||
})
|
||||
}
|
||||
|
||||
func (b *Broker) handleConsumer(cmd consts.CMD, state consts.ConsumerState, consumerID string, queues ...string) {
|
||||
func (b *Broker) handleConsumer(ctx context.Context, cmd consts.CMD, state consts.ConsumerState, consumerID string, queues ...string) {
|
||||
fn := func(queue *Queue) {
|
||||
con, ok := queue.consumers.Get(consumerID)
|
||||
if ok {
|
||||
ack := codec.NewMessage(cmd, utils.ToByte("{}"), queue.name, map[string]string{consts.ConsumerKey: consumerID})
|
||||
err := b.send(con.conn, ack)
|
||||
err := b.send(ctx, con.conn, ack)
|
||||
if err == nil {
|
||||
con.state = state
|
||||
}
|
||||
@@ -385,20 +382,20 @@ func (b *Broker) handleConsumer(cmd consts.CMD, state consts.ConsumerState, cons
|
||||
})
|
||||
}
|
||||
|
||||
func (b *Broker) PauseConsumer(consumerID string, queues ...string) {
|
||||
b.handleConsumer(consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, queues...)
|
||||
func (b *Broker) PauseConsumer(ctx context.Context, consumerID string, queues ...string) {
|
||||
b.handleConsumer(ctx, consts.CONSUMER_PAUSE, consts.ConsumerStatePaused, consumerID, queues...)
|
||||
}
|
||||
|
||||
func (b *Broker) ResumeConsumer(consumerID string, queues ...string) {
|
||||
b.handleConsumer(consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, queues...)
|
||||
func (b *Broker) ResumeConsumer(ctx context.Context, consumerID string, queues ...string) {
|
||||
b.handleConsumer(ctx, consts.CONSUMER_RESUME, consts.ConsumerStateActive, consumerID, queues...)
|
||||
}
|
||||
|
||||
func (b *Broker) StopConsumer(consumerID string, queues ...string) {
|
||||
b.handleConsumer(consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, queues...)
|
||||
func (b *Broker) StopConsumer(ctx context.Context, consumerID string, queues ...string) {
|
||||
b.handleConsumer(ctx, consts.CONSUMER_STOP, consts.ConsumerStateStopped, consumerID, queues...)
|
||||
}
|
||||
|
||||
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
|
||||
msg, err := b.receive(c)
|
||||
msg, err := b.receive(ctx, c)
|
||||
if err == nil {
|
||||
ctx = SetHeaders(ctx, msg.Headers)
|
||||
b.OnMessage(ctx, msg, c)
|
||||
@@ -412,12 +409,12 @@ func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *Broker) dispatchWorker(queue *Queue) {
|
||||
func (b *Broker) dispatchWorker(ctx context.Context, queue *Queue) {
|
||||
delay := b.opts.initialDelay
|
||||
for task := range queue.tasks {
|
||||
success := false
|
||||
for !success && task.RetryCount <= b.opts.maxRetries {
|
||||
if b.dispatchTaskToConsumer(queue, task) {
|
||||
if b.dispatchTaskToConsumer(ctx, queue, task) {
|
||||
success = true
|
||||
} else {
|
||||
task.RetryCount++
|
||||
@@ -440,7 +437,7 @@ func (b *Broker) sendToDLQ(queue *Queue, task *QueuedTask) {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool {
|
||||
func (b *Broker) dispatchTaskToConsumer(ctx context.Context, queue *Queue, task *QueuedTask) bool {
|
||||
var consumerFound bool
|
||||
var err error
|
||||
queue.consumers.ForEach(func(_ string, con *consumer) bool {
|
||||
@@ -448,7 +445,7 @@ func (b *Broker) dispatchTaskToConsumer(queue *Queue, task *QueuedTask) bool {
|
||||
err = fmt.Errorf("consumer %s is not active", con.id)
|
||||
return true
|
||||
}
|
||||
if err := b.send(con.conn, task.Message); err == nil {
|
||||
if err := b.send(ctx, con.conn, task.Message); err == nil {
|
||||
consumerFound = true
|
||||
return false
|
||||
}
|
||||
|
@@ -1,6 +1,8 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"sync"
|
||||
@@ -12,11 +14,13 @@ type Message struct {
|
||||
Headers map[string]string `msgpack:"h"`
|
||||
Queue string `msgpack:"q"`
|
||||
Payload []byte `msgpack:"p"`
|
||||
m sync.RWMutex
|
||||
Command consts.CMD `msgpack:"c"`
|
||||
Command consts.CMD `msgpack:"c"`
|
||||
}
|
||||
|
||||
func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string]string) *Message {
|
||||
if headers == nil {
|
||||
headers = make(map[string]string)
|
||||
}
|
||||
return &Message{
|
||||
Headers: headers,
|
||||
Queue: queue,
|
||||
@@ -26,8 +30,6 @@ func NewMessage(cmd consts.CMD, payload []byte, queue string, headers map[string
|
||||
}
|
||||
|
||||
func (m *Message) Serialize() ([]byte, error) {
|
||||
m.m.RLock()
|
||||
defer m.m.RUnlock()
|
||||
data, err := Marshal(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -40,37 +42,66 @@ func Deserialize(data []byte) (*Message, error) {
|
||||
if err := Unmarshal(data, &msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
func SendMessage(conn net.Conn, msg *Message) error {
|
||||
var byteBufferPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, 0, 4096)
|
||||
},
|
||||
}
|
||||
|
||||
func SendMessage(ctx context.Context, conn net.Conn, msg *Message) error {
|
||||
data, err := msg.Serialize()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
length := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(length, uint32(len(data)))
|
||||
totalLength := 4 + len(data)
|
||||
buffer := byteBufferPool.Get().([]byte)
|
||||
if cap(buffer) < totalLength {
|
||||
buffer = make([]byte, totalLength)
|
||||
} else {
|
||||
buffer = buffer[:totalLength]
|
||||
}
|
||||
defer byteBufferPool.Put(buffer)
|
||||
|
||||
if _, err := conn.Write(length); err != nil {
|
||||
return err
|
||||
binary.BigEndian.PutUint32(buffer[:4], uint32(len(data)))
|
||||
copy(buffer[4:], data)
|
||||
|
||||
writer := bufio.NewWriter(conn)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
if _, err := writer.Write(buffer); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
return writer.Flush()
|
||||
}
|
||||
|
||||
func ReadMessage(conn net.Conn) (*Message, error) {
|
||||
func ReadMessage(ctx context.Context, conn net.Conn) (*Message, error) {
|
||||
lengthBytes := make([]byte, 4)
|
||||
if _, err := conn.Read(lengthBytes); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
length := binary.BigEndian.Uint32(lengthBytes)
|
||||
data := make([]byte, length)
|
||||
if _, err := conn.Read(data); err != nil {
|
||||
return nil, err
|
||||
data := byteBufferPool.Get().([]byte)[:length]
|
||||
defer byteBufferPool.Put(data)
|
||||
totalRead := 0
|
||||
for totalRead < int(length) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
n, err := conn.Read(data[totalRead:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
totalRead += n
|
||||
}
|
||||
}
|
||||
return Deserialize(data)
|
||||
return Deserialize(data[:length])
|
||||
}
|
||||
|
51
consumer.go
51
consumer.go
@@ -7,7 +7,6 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq/codec"
|
||||
@@ -47,12 +46,12 @@ func NewConsumer(id string, queue string, handler Handler, opts ...Option) *Cons
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Consumer) send(conn net.Conn, msg *codec.Message) error {
|
||||
return codec.SendMessage(conn, msg)
|
||||
func (c *Consumer) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
|
||||
return codec.SendMessage(ctx, conn, msg)
|
||||
}
|
||||
|
||||
func (c *Consumer) receive(conn net.Conn) (*codec.Message, error) {
|
||||
return codec.ReadMessage(conn)
|
||||
func (c *Consumer) receive(ctx context.Context, conn net.Conn) (*codec.Message, error) {
|
||||
return codec.ReadMessage(ctx, conn)
|
||||
}
|
||||
|
||||
func (c *Consumer) Close() error {
|
||||
@@ -75,10 +74,10 @@ func (c *Consumer) SetKey(key string) {
|
||||
func (c *Consumer) subscribe(ctx context.Context, queue string) error {
|
||||
headers := HeadersWithConsumerID(ctx, c.id)
|
||||
msg := codec.NewMessage(consts.SUBSCRIBE, utils.ToByte("{}"), queue, headers)
|
||||
if err := c.send(c.conn, msg); err != nil {
|
||||
if err := c.send(ctx, c.conn, msg); err != nil {
|
||||
return fmt.Errorf("error while trying to subscribe: %v", err)
|
||||
}
|
||||
return c.waitForAck(c.conn)
|
||||
return c.waitForAck(ctx, c.conn)
|
||||
}
|
||||
|
||||
func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error {
|
||||
@@ -122,7 +121,7 @@ func (c *Consumer) sendMessageAck(ctx context.Context, msg *codec.Message, conn
|
||||
headers := HeadersWithConsumerIDAndQueue(ctx, c.id, msg.Queue)
|
||||
taskID, _ := jsonparser.GetString(msg.Payload, "id")
|
||||
reply := codec.NewMessage(consts.MESSAGE_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, headers)
|
||||
if err := c.send(conn, reply); err != nil {
|
||||
if err := c.send(ctx, conn, reply); err != nil {
|
||||
fmt.Printf("failed to send MESSAGE_ACK for queue %s: %v", msg.Queue, err)
|
||||
}
|
||||
}
|
||||
@@ -172,7 +171,7 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
|
||||
if result.Payload != nil || result.Error != nil {
|
||||
bt, _ := json.Marshal(result)
|
||||
reply := codec.NewMessage(consts.MESSAGE_RESPONSE, bt, result.Topic, headers)
|
||||
if err := c.send(c.conn, reply); err != nil {
|
||||
if err := c.send(ctx, c.conn, reply); err != nil {
|
||||
return fmt.Errorf("failed to send MESSAGE_RESPONSE: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -182,7 +181,7 @@ func (c *Consumer) OnResponse(ctx context.Context, result Result) error {
|
||||
func (c *Consumer) sendDenyMessage(ctx context.Context, taskID, queue string, err error) {
|
||||
headers := HeadersWithConsumerID(ctx, c.id)
|
||||
reply := codec.NewMessage(consts.MESSAGE_DENY, utils.ToByte(fmt.Sprintf(`{"id":"%s", "error":"%s"}`, taskID, err.Error())), queue, headers)
|
||||
if sendErr := c.send(c.conn, reply); sendErr != nil {
|
||||
if sendErr := c.send(ctx, c.conn, reply); sendErr != nil {
|
||||
log.Printf("failed to send MESSAGE_DENY for task %s: %v", taskID, sendErr)
|
||||
}
|
||||
}
|
||||
@@ -209,7 +208,7 @@ func (c *Consumer) attemptConnect() error {
|
||||
}
|
||||
|
||||
func (c *Consumer) readMessage(ctx context.Context, conn net.Conn) error {
|
||||
msg, err := c.receive(conn)
|
||||
msg, err := c.receive(ctx, conn)
|
||||
if err == nil {
|
||||
ctx = SetHeaders(ctx, msg.Headers)
|
||||
return c.OnMessage(ctx, msg, conn)
|
||||
@@ -235,24 +234,32 @@ func (c *Consumer) Consume(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to connect to server for queue %s: %v", c.queue, err)
|
||||
}
|
||||
c.pool.Start(c.opts.numOfWorkers)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
stopChan := make(chan struct{})
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer close(stopChan) // Signal completion when done
|
||||
for {
|
||||
if err := c.readMessage(ctx, c.conn); err != nil {
|
||||
log.Println("Error reading message:", err)
|
||||
break
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Println("Context canceled, stopping message reading.")
|
||||
return
|
||||
default:
|
||||
if err := c.readMessage(ctx, c.conn); err != nil {
|
||||
log.Println("Error reading message:", err)
|
||||
return // Exit the goroutine on error
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
select {
|
||||
case <-stopChan:
|
||||
case <-ctx.Done():
|
||||
log.Println("Context canceled, performing cleanup.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Consumer) waitForAck(conn net.Conn) error {
|
||||
msg, err := c.receive(conn)
|
||||
func (c *Consumer) waitForAck(ctx context.Context, conn net.Conn) error {
|
||||
msg, err := c.receive(ctx, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -286,7 +293,7 @@ func (c *Consumer) operate(ctx context.Context, cmd consts.CMD, poolOperation fu
|
||||
func (c *Consumer) sendOpsMessage(ctx context.Context, cmd consts.CMD) error {
|
||||
headers := HeadersWithConsumerID(ctx, c.id)
|
||||
msg := codec.NewMessage(cmd, nil, c.queue, headers)
|
||||
return c.send(c.conn, msg)
|
||||
return c.send(ctx, c.conn, msg)
|
||||
}
|
||||
|
||||
func (c *Consumer) Conn() net.Conn {
|
||||
|
@@ -9,6 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
consumer1 := mq.NewConsumer("F", "queue1", tasks.Node6{}.ProcessTask, mq.WithWorkerPool(100, 4, 50000))
|
||||
n := &tasks.Node6{}
|
||||
consumer1 := mq.NewConsumer("F", "queue1", n.ProcessTask, mq.WithWorkerPool(100, 4, 50000))
|
||||
consumer1.Consume(context.Background())
|
||||
}
|
||||
|
@@ -3,7 +3,6 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
)
|
||||
@@ -25,7 +24,6 @@ func main() {
|
||||
Payload: payload,
|
||||
}
|
||||
for i := 0; i < 100; i++ {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
err := publisher.Publish(context.Background(), task, "queue1")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
14
publisher.go
14
publisher.go
@@ -37,15 +37,15 @@ func (p *Publisher) send(ctx context.Context, queue string, task Task, conn net.
|
||||
return err
|
||||
}
|
||||
msg := codec.NewMessage(command, payload, queue, headers)
|
||||
if err := codec.SendMessage(conn, msg); err != nil {
|
||||
if err := codec.SendMessage(ctx, conn, msg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return p.waitForAck(conn)
|
||||
return p.waitForAck(ctx, conn)
|
||||
}
|
||||
|
||||
func (p *Publisher) waitForAck(conn net.Conn) error {
|
||||
msg, err := codec.ReadMessage(conn)
|
||||
func (p *Publisher) waitForAck(ctx context.Context, conn net.Conn) error {
|
||||
msg, err := codec.ReadMessage(ctx, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -57,8 +57,8 @@ func (p *Publisher) waitForAck(conn net.Conn) error {
|
||||
return fmt.Errorf("expected PUBLISH_ACK, got: %v", msg.Command)
|
||||
}
|
||||
|
||||
func (p *Publisher) waitForResponse(conn net.Conn) Result {
|
||||
msg, err := codec.ReadMessage(conn)
|
||||
func (p *Publisher) waitForResponse(ctx context.Context, conn net.Conn) Result {
|
||||
msg, err := codec.ReadMessage(ctx, conn)
|
||||
if err != nil {
|
||||
return Result{Error: err}
|
||||
}
|
||||
@@ -103,7 +103,7 @@ func (p *Publisher) Request(ctx context.Context, queue string, task Task) Result
|
||||
resultCh := make(chan Result)
|
||||
go func() {
|
||||
defer close(resultCh)
|
||||
resultCh <- p.waitForResponse(conn)
|
||||
resultCh <- p.waitForResponse(ctx, conn)
|
||||
}()
|
||||
finalResult := <-resultCh
|
||||
return finalResult
|
||||
|
6
queue.go
6
queue.go
@@ -36,9 +36,9 @@ func (b *Broker) NewQueue(name string) *Queue {
|
||||
consumers: memory.New[string, *consumer](),
|
||||
}
|
||||
b.deadLetter.Set(name, dlq)
|
||||
|
||||
go b.dispatchWorker(q)
|
||||
go b.dispatchWorker(dlq)
|
||||
ctx := context.Background()
|
||||
go b.dispatchWorker(ctx, q)
|
||||
go b.dispatchWorker(ctx, dlq)
|
||||
return q
|
||||
}
|
||||
|
||||
|
@@ -8,94 +8,97 @@ import (
|
||||
|
||||
var _ storage.IMap[string, any] = (*Map[string, any])(nil)
|
||||
|
||||
// Map is a thread-safe map using sync.Map with generics
|
||||
type Map[K comparable, V any] struct {
|
||||
data map[K]V
|
||||
mu sync.RWMutex
|
||||
m sync.Map
|
||||
}
|
||||
|
||||
// New creates a new Map
|
||||
func New[K comparable, V any]() *Map[K, V] {
|
||||
return &Map[K, V]{
|
||||
data: make(map[K]V),
|
||||
return &Map[K, V]{}
|
||||
}
|
||||
|
||||
// Get retrieves the value for a given key
|
||||
func (g *Map[K, V]) Get(key K) (V, bool) {
|
||||
val, ok := g.m.Load(key)
|
||||
if !ok {
|
||||
var zeroValue V
|
||||
return zeroValue, false
|
||||
}
|
||||
return val.(V), true
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Get(key K) (V, bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
val, exists := m.data[key]
|
||||
return val, exists
|
||||
// Set adds a key-value pair to the map
|
||||
func (g *Map[K, V]) Set(key K, value V) {
|
||||
g.m.Store(key, value)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Set(key K, value V) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.data[key] = value
|
||||
// Del removes a key-value pair from the map
|
||||
func (g *Map[K, V]) Del(key K) {
|
||||
g.m.Delete(key)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Del(key K) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.data, key)
|
||||
// ForEach iterates over the map
|
||||
func (g *Map[K, V]) ForEach(fn func(K, V) bool) {
|
||||
g.m.Range(func(k, v any) bool {
|
||||
return fn(k.(K), v.(V))
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) ForEach(f func(K, V) bool) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
for k, v := range m.data {
|
||||
if !f(k, v) {
|
||||
break
|
||||
}
|
||||
}
|
||||
// Clear removes all key-value pairs from the map
|
||||
func (g *Map[K, V]) Clear() {
|
||||
g.ForEach(func(k K, v V) bool {
|
||||
g.Del(k)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Clear() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.data = make(map[K]V)
|
||||
// Size returns the number of key-value pairs in the map
|
||||
func (g *Map[K, V]) Size() int {
|
||||
count := 0
|
||||
g.ForEach(func(_ K, _ V) bool {
|
||||
count++
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Size() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return len(m.data)
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Keys() []K {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
keys := make([]K, 0, len(m.data))
|
||||
for k := range m.data {
|
||||
// Keys returns a slice of all keys in the map
|
||||
func (g *Map[K, V]) Keys() []K {
|
||||
keys := []K{}
|
||||
g.ForEach(func(k K, _ V) bool {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return true
|
||||
})
|
||||
return keys
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Values() []V {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
values := make([]V, 0, len(m.data))
|
||||
for _, v := range m.data {
|
||||
// Values returns a slice of all values in the map
|
||||
func (g *Map[K, V]) Values() []V {
|
||||
values := []V{}
|
||||
g.ForEach(func(_ K, v V) bool {
|
||||
values = append(values, v)
|
||||
}
|
||||
return true
|
||||
})
|
||||
return values
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) AsMap() map[K]V {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
copiedMap := make(map[K]V, len(m.data))
|
||||
for k, v := range m.data {
|
||||
copiedMap[k] = v
|
||||
}
|
||||
return copiedMap
|
||||
// AsMap returns a regular map containing all key-value pairs
|
||||
func (g *Map[K, V]) AsMap() map[K]V {
|
||||
result := make(map[K]V)
|
||||
g.ForEach(func(k K, v V) bool {
|
||||
result[k] = v
|
||||
return true
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *Map[K, V]) Clone() storage.IMap[K, V] {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
clonedMap := New[K, V]()
|
||||
for k, v := range m.data {
|
||||
clonedMap.Set(k, v)
|
||||
}
|
||||
return clonedMap
|
||||
// Clone creates a shallow copy of the map
|
||||
func (g *Map[K, V]) Clone() storage.IMap[K, V] {
|
||||
clone := New[K, V]()
|
||||
g.ForEach(func(k K, v V) bool {
|
||||
clone.Set(k, v)
|
||||
return true
|
||||
})
|
||||
return clone
|
||||
}
|
||||
|
15
utils/conn.go
Normal file
15
utils/conn.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func localAddr(c net.Conn) string { return c.LocalAddr().String() }
|
||||
func remoteAddr(c net.Conn) string { return c.RemoteAddr().String() }
|
||||
|
||||
func ConnectionsEqual(c1, c2 net.Conn) bool {
|
||||
if c1 == nil || c2 == nil {
|
||||
return false
|
||||
}
|
||||
return localAddr(c1) == localAddr(c2) && remoteAddr(c1) == remoteAddr(c2)
|
||||
}
|
Reference in New Issue
Block a user