mirror of
https://github.com/oarkflow/mq.git
synced 2025-10-06 16:36:53 +08:00
333 lines
8.2 KiB
Go
333 lines
8.2 KiB
Go
package v2
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/oarkflow/xsync"
|
|
|
|
"github.com/oarkflow/mq/codec"
|
|
"github.com/oarkflow/mq/consts"
|
|
)
|
|
|
|
type consumer struct {
|
|
id string
|
|
conn net.Conn
|
|
}
|
|
|
|
func (p *consumer) send(ctx context.Context, cmd any) error {
|
|
return nil
|
|
}
|
|
|
|
type publisher struct {
|
|
id string
|
|
conn net.Conn
|
|
}
|
|
|
|
func (p *publisher) send(ctx context.Context, cmd any) error {
|
|
return nil
|
|
}
|
|
|
|
type Handler func(context.Context, Task) Result
|
|
|
|
type Broker struct {
|
|
queues xsync.IMap[string, *Queue]
|
|
consumers xsync.IMap[string, *consumer]
|
|
publishers xsync.IMap[string, *publisher]
|
|
opts Options
|
|
}
|
|
|
|
type Queue struct {
|
|
name string
|
|
consumers xsync.IMap[string, *consumer]
|
|
messages xsync.IMap[string, *Task]
|
|
deferred xsync.IMap[string, *Task]
|
|
}
|
|
|
|
func newQueue(name string) *Queue {
|
|
return &Queue{
|
|
name: name,
|
|
consumers: xsync.NewMap[string, *consumer](),
|
|
messages: xsync.NewMap[string, *Task](),
|
|
deferred: xsync.NewMap[string, *Task](),
|
|
}
|
|
}
|
|
|
|
func (queue *Queue) send(ctx context.Context, cmd any) {
|
|
queue.consumers.ForEach(func(_ string, client *consumer) bool {
|
|
err := client.send(ctx, cmd)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
|
|
type Task struct {
|
|
ID string `json:"id"`
|
|
Payload json.RawMessage `json:"payload"`
|
|
CreatedAt time.Time `json:"created_at"`
|
|
ProcessedAt time.Time `json:"processed_at"`
|
|
Status string `json:"status"`
|
|
Error error `json:"error"`
|
|
}
|
|
|
|
func NewBroker(opts ...Option) *Broker {
|
|
options := defaultOptions()
|
|
for _, opt := range opts {
|
|
opt(&options)
|
|
}
|
|
b := &Broker{
|
|
queues: xsync.NewMap[string, *Queue](),
|
|
publishers: xsync.NewMap[string, *publisher](),
|
|
consumers: xsync.NewMap[string, *consumer](),
|
|
opts: options,
|
|
}
|
|
return b
|
|
}
|
|
|
|
func (b *Broker) TLSConfig() TLSConfig {
|
|
return b.opts.tlsConfig
|
|
}
|
|
|
|
func (b *Broker) SyncMode() bool {
|
|
return b.opts.syncMode
|
|
}
|
|
|
|
func (b *Broker) OnClose(ctx context.Context, _ net.Conn) error {
|
|
consumerID, ok := GetConsumerID(ctx)
|
|
if ok && consumerID != "" {
|
|
if con, exists := b.consumers.Get(consumerID); exists {
|
|
con.conn.Close()
|
|
b.consumers.Del(consumerID)
|
|
}
|
|
b.queues.ForEach(func(_ string, queue *Queue) bool {
|
|
queue.consumers.Del(consumerID)
|
|
return true
|
|
})
|
|
}
|
|
publisherID, ok := GetPublisherID(ctx)
|
|
if ok && publisherID != "" {
|
|
if con, exists := b.publishers.Get(publisherID); exists {
|
|
con.conn.Close()
|
|
b.publishers.Del(publisherID)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) {
|
|
fmt.Println("Error reading from connection:", err, conn.RemoteAddr())
|
|
}
|
|
|
|
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
|
switch msg.Command {
|
|
case consts.PUBLISH:
|
|
b.publish(ctx, conn, msg)
|
|
case consts.SUBSCRIBE:
|
|
b.subscribe(ctx, conn, msg)
|
|
case consts.MESSAGE_RESPONSE:
|
|
headers, ok := GetHeaders(ctx)
|
|
if ok {
|
|
if awaitResponse, ok := headers[consts.AwaitResponseKey]; ok && awaitResponse == "true" {
|
|
publisherID, exists := headers[consts.PublisherKey]
|
|
if exists {
|
|
con, ok := b.publishers.Get(publisherID)
|
|
if ok {
|
|
msg.Command = consts.RESPONSE
|
|
err := codec.SendMessage(con.conn, msg, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
fmt.Println("consumer confirmed", headers, ok, string(msg.Payload))
|
|
case consts.MESSAGE_ACK:
|
|
fmt.Println("consumer confirmed")
|
|
}
|
|
}
|
|
|
|
func (b *Broker) readMessage(ctx context.Context, c net.Conn) error {
|
|
msg, err := codec.ReadMessage(c, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
|
|
if err == nil {
|
|
ctx = SetHeaders(ctx, msg.Headers)
|
|
b.OnMessage(ctx, msg, c)
|
|
return nil
|
|
}
|
|
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
|
|
b.OnClose(ctx, c)
|
|
return err
|
|
}
|
|
b.OnError(ctx, c, err)
|
|
return err
|
|
}
|
|
|
|
// Start the broker server with optional TLS support
|
|
func (b *Broker) Start(ctx context.Context) error {
|
|
var listener net.Listener
|
|
var err error
|
|
|
|
if b.opts.tlsConfig.UseTLS {
|
|
cert, err := tls.LoadX509KeyPair(b.opts.tlsConfig.CertPath, b.opts.tlsConfig.KeyPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to load TLS certificates: %v", err)
|
|
}
|
|
tlsConfig := &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
}
|
|
listener, err = tls.Listen("tcp", b.opts.brokerAddr, tlsConfig)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to start TLS listener: %v", err)
|
|
}
|
|
log.Println("TLS server started on", b.opts.brokerAddr)
|
|
} else {
|
|
listener, err = net.Listen("tcp", b.opts.brokerAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to start TCP listener: %v", err)
|
|
}
|
|
log.Println("TCP server started on", b.opts.brokerAddr)
|
|
}
|
|
defer listener.Close()
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
b.OnError(ctx, conn, err)
|
|
continue
|
|
}
|
|
go func(c net.Conn) {
|
|
defer c.Close()
|
|
for {
|
|
err := b.readMessage(ctx, c)
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
}(conn)
|
|
}
|
|
}
|
|
|
|
func (b *Broker) NewQueue(qName string) *Queue {
|
|
q, ok := b.queues.Get(qName)
|
|
if ok {
|
|
return q
|
|
}
|
|
q = newQueue(qName)
|
|
b.queues.Set(qName, q)
|
|
return q
|
|
}
|
|
|
|
func (b *Broker) AddMessageToQueue(task *Task, queueName string) (*Queue, *Task, error) {
|
|
queue := b.NewQueue(queueName)
|
|
if task.ID == "" {
|
|
task.ID = NewID()
|
|
}
|
|
task.CreatedAt = time.Now()
|
|
queue.messages.Set(task.ID, task)
|
|
return queue, task, nil
|
|
}
|
|
|
|
func (b *Broker) addConsumer(ctx context.Context, queueName string, conn net.Conn) string {
|
|
consumerID, ok := GetConsumerID(ctx)
|
|
q, ok := b.queues.Get(queueName)
|
|
if !ok {
|
|
q = b.NewQueue(queueName)
|
|
}
|
|
con := &consumer{id: consumerID, conn: conn}
|
|
b.consumers.Set(consumerID, con)
|
|
q.consumers.Set(consumerID, con)
|
|
log.Printf("Consumer %s joined server on queue %s", consumerID, queueName)
|
|
return consumerID
|
|
}
|
|
|
|
func (b *Broker) addPublisher(ctx context.Context, queueName string, conn net.Conn) *publisher {
|
|
publisherID, ok := GetPublisherID(ctx)
|
|
_, ok = b.queues.Get(queueName)
|
|
if !ok {
|
|
b.NewQueue(queueName)
|
|
}
|
|
con := &publisher{id: publisherID, conn: conn}
|
|
b.publishers.Set(publisherID, con)
|
|
return con
|
|
}
|
|
|
|
func (b *Broker) broadcastToConsumers(ctx context.Context, msg *codec.Message) {
|
|
if queue, ok := b.queues.Get(msg.Queue); ok {
|
|
queue.consumers.ForEach(func(_ string, con *consumer) bool {
|
|
msg.Command = consts.MESSAGE_SEND
|
|
if err := codec.SendMessage(con.conn, msg, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption); err != nil {
|
|
log.Printf("Error sending Message: %v\n", err)
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}
|
|
|
|
func (b *Broker) waitForAck(conn net.Conn) error {
|
|
msg, err := codec.ReadMessage(conn, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if msg.Command == consts.MESSAGE_ACK {
|
|
log.Println("Received CONSUMER_ACK: Subscribed successfully")
|
|
return nil
|
|
}
|
|
return fmt.Errorf("expected CONSUMER_ACK, got: %v", msg.Command)
|
|
}
|
|
|
|
func (b *Broker) publish(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
|
pub := b.addPublisher(ctx, msg.Queue, conn)
|
|
ack := &codec.Message{
|
|
Headers: msg.Headers,
|
|
Queue: msg.Queue,
|
|
Command: consts.PUBLISH_ACK,
|
|
}
|
|
if err := codec.SendMessage(conn, ack, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption); err != nil {
|
|
log.Printf("Error sending PUBLISH_ACK: %v\n", err)
|
|
}
|
|
b.broadcastToConsumers(ctx, msg)
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
b.publishers.Del(pub.id)
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (b *Broker) subscribe(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
|
consumerID := b.addConsumer(ctx, msg.Queue, conn)
|
|
ack := &codec.Message{
|
|
Headers: msg.Headers,
|
|
Queue: msg.Queue,
|
|
Command: consts.SUBSCRIBE_ACK,
|
|
}
|
|
if err := codec.SendMessage(conn, ack, b.opts.aesKey, b.opts.hmacKey, b.opts.enableEncryption); err != nil {
|
|
log.Printf("Error sending SUBSCRIBE_ACK: %v\n", err)
|
|
}
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
b.removeConsumer(msg.Queue, consumerID)
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Removes connection from the queue and broker
|
|
func (b *Broker) removeConsumer(queueName, consumerID string) {
|
|
if queue, ok := b.queues.Get(queueName); ok {
|
|
con, ok := queue.consumers.Get(consumerID)
|
|
if ok {
|
|
con.conn.Close()
|
|
queue.consumers.Del(consumerID)
|
|
}
|
|
b.queues.Del(queueName)
|
|
}
|
|
}
|