mirror of https://github.com/oarkflow/mq.git (synced 2025-09-26 20:11:16 +08:00)

Commit: update
ctx.go (2 changed lines)
@@ -148,7 +148,7 @@ func GetConnection(addr string, config TLSConfig) (net.Conn, error) {
        return nil, err
    }
    // Store the new connection in the pool.
    connPool.Store(key, conn)
    // connPool.Store(key, conn) // Disable pooling for now
    return conn, nil
}
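The change above comments out the Store call, so GetConnection stops caching connections and every caller dials fresh. For orientation only, a minimal sketch of the pooling pattern being disabled, assuming a package-level sync.Map keyed the way the snippet suggests (the repository's actual pool and key layout may differ):

package connpool

import (
    "net"
    "sync"
)

// connPool caches one established connection per key (illustrative only).
var connPool sync.Map

// getConnection returns a cached connection for key, dialing on a miss.
func getConnection(key, addr string) (net.Conn, error) {
    if v, ok := connPool.Load(key); ok {
        return v.(net.Conn), nil
    }
    conn, err := net.Dial("tcp", addr)
    if err != nil {
        return nil, err
    }
    // The commit above disables the equivalent of this Store call,
    // so subsequent callers no longer reuse the connection.
    connPool.Store(key, conn)
    return conn, nil
}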
@@ -20,10 +20,11 @@ func main() {
    )
    for i := 0; i < 2; i++ {
        // publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"))
        err := publisher.Publish(context.Background(), task, "queue1")
        if err != nil {
            panic(err)
        }
        result := publisher.Request(context.Background(), task, "queue1")
        if result.Error != nil {
            panic(result.Error)
        }
        fmt.Println(string(result.Payload))
    }
    fmt.Println("Async task published successfully")
}
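The hunk references a publisher and a task declared above the visible lines. A hedged reconstruction of the surrounding example after this change is sketched below; NewPublisher, Publish, and Request appear in the diff, while the Task construction and its fields are assumptions:

package main

import (
    "context"
    "fmt"

    "github.com/oarkflow/mq"
)

func main() {
    // Mirrors the commented-out constructor in the hunk, minus the TLS options.
    publisher := mq.NewPublisher("publish-1")

    // Assumed task shape: the broker extracts an "id" field from the payload.
    task := mq.Task{Payload: []byte(`{"id":"task-1","message":"hello"}`)}

    for i := 0; i < 2; i++ {
        if err := publisher.Publish(context.Background(), task, "queue1"); err != nil {
            panic(err)
        }
        result := publisher.Request(context.Background(), task, "queue1")
        if result.Error != nil {
            panic(result.Error)
        }
        fmt.Println(string(result.Payload))
    }
    fmt.Println("Async task published successfully")
}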
@@ -24,7 +24,7 @@ func main() {
    }
    b.NewQueue("queue1")
    b.NewQueue("queue2")
    b.StartEnhanced(context.Background())
    b.Start(context.Background())
}

// InitializeDefaults adds default permissions, roles, and users for development/testing
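Read together with the mq.go changes below, this hunk swaps the example from StartEnhanced to plain Start, which now carries the production features itself. A hedged sketch of the broker example after the change (NewBroker, NewQueue, and Start are taken from the diff; the error handling around Start is added here for illustration):

package main

import (
    "context"

    "github.com/oarkflow/mq"
)

func main() {
    b := mq.NewBroker()

    // Declare the queues the examples publish to and consume from.
    b.NewQueue("queue1")
    b.NewQueue("queue2")

    // After this commit, Start also launches the health checker, cleanup
    // routines, metrics collection, and the optional admin/metrics servers.
    if err := b.Start(context.Background()); err != nil {
        panic(err)
    }
}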
mq.go (230 changed lines)
@@ -2,7 +2,6 @@ package mq

import (
    "context"
    "crypto/tls"
    "fmt"
    "log"
    "net"
@@ -481,7 +480,8 @@ type Broker struct {
    securityManager    *SecurityManager
    adminServer        *AdminServer
    metricsServer      *MetricsServer
    authenticatedConns storage.IMap[string, bool]              // authenticated connections
    taskHeaders        storage.IMap[string, map[string]string] // task headers by task ID
    isShutdown         int32
    shutdown           chan struct{}
    wg                 sync.WaitGroup
@@ -507,6 +507,7 @@ func NewBroker(opts ...Option) *Broker {
        metricsCollector:   NewMetricsCollector(),
        messageStore:       NewInMemoryMessageStore(),
        authenticatedConns: memory.New[string, bool](),
        taskHeaders:        memory.New[string, map[string]string](),
        shutdown:           make(chan struct{}),
        logger:             options.Logger(),
    }
@@ -623,6 +624,9 @@ func (b *Broker) isAuthenticated(connID string) bool {
}

func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
    // Set message headers in context for publisher/consumer ID extraction
    ctx = SetHeaders(ctx, msg.Headers)

    connID := conn.RemoteAddr().String()

    // Check authentication for protected commands
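SetHeaders attaches the incoming message headers to the request context so later handlers can recover the publisher or consumer ID. The real helper lives in ctx.go; a minimal sketch of the usual pattern (typed context key, defensive copy), with every name other than SetHeaders assumed, looks like this:

package headerctx

import "context"

// headersKey is an unexported, typed context key (assumed; the actual key is in ctx.go).
type headersKey struct{}

// SetHeaders returns a child context carrying a copy of the message headers.
func SetHeaders(ctx context.Context, headers map[string]string) context.Context {
    copied := make(map[string]string, len(headers))
    for k, v := range headers {
        copied[k] = v
    }
    return context.WithValue(ctx, headersKey{}, copied)
}

// GetHeader reads one header back out, reporting whether it was present.
func GetHeader(ctx context.Context, key string) (string, bool) {
    headers, ok := ctx.Value(headersKey{}).(map[string]string)
    if !ok {
        return "", false
    }
    v, ok := headers[key]
    return v, ok
}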
@@ -783,6 +787,21 @@ func (b *Broker) OnConsumerResume(ctx context.Context, _ *codec.Message) {
}

func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) {
    // Extract task ID from response
    var result Result
    if err := json.Unmarshal(msg.Payload, &result); err != nil {
        log.Printf("Error unmarshaling response: %v", err)
        return
    }
    taskID := result.TaskID

    // Retrieve stored headers for this task
    if headers, ok := b.taskHeaders.Get(taskID); ok {
        ctx = SetHeaders(ctx, headers)
        // Clean up stored headers
        b.taskHeaders.Del(taskID)
    }

    msg.Command = consts.RESPONSE
    b.HandleCallback(ctx, msg)
    awaitResponse, ok := GetAwaitResponse(ctx)
@@ -822,6 +841,9 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
    taskID, _ := jsonparser.GetString(msg.Payload, "id")
    log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)

    // Store headers for response routing
    b.taskHeaders.Set(taskID, msg.Headers)

    ack, err := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
    if err != nil {
        log.Printf("Error creating PUBLISH_ACK message: %v\n", err)
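Together with MessageResponseHandler above, this gives a simple correlation scheme: headers are stored under the task ID at publish time, re-attached when the response arrives, and then deleted. A condensed, self-contained sketch of that store/retrieve/forget flow, using a mutex-guarded map instead of the repository's storage.IMap:

package routing

import "sync"

// headerStore correlates task IDs with the headers of the original PUBLISH
// so that a later RESPONSE can be routed back with the same metadata.
type headerStore struct {
    mu     sync.Mutex
    byTask map[string]map[string]string
}

func newHeaderStore() *headerStore {
    return &headerStore{byTask: make(map[string]map[string]string)}
}

// OnPublish records the headers for a task (mirrors b.taskHeaders.Set in the diff).
func (s *headerStore) OnPublish(taskID string, headers map[string]string) {
    s.mu.Lock()
    defer s.mu.Unlock()
    s.byTask[taskID] = headers
}

// OnResponse returns the stored headers once and forgets them
// (mirrors the Get followed by Del in MessageResponseHandler).
func (s *headerStore) OnResponse(taskID string) (map[string]string, bool) {
    s.mu.Lock()
    defer s.mu.Unlock()
    headers, ok := s.byTask[taskID]
    if ok {
        delete(s.byTask, taskID)
    }
    return headers, ok
}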
@@ -861,110 +883,58 @@ func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec
}

func (b *Broker) Start(ctx context.Context) error {
    var listener net.Listener
    var err error
    if b.opts.tlsConfig.UseTLS {
        cert, err := tls.LoadX509KeyPair(b.opts.tlsConfig.CertPath, b.opts.tlsConfig.KeyPath)
        if err != nil {
            return WrapError(err, "failed to load TLS certificates for broker", "BROKER_TLS_CERT_ERROR")
        }
        tlsConfig := &tls.Config{
            Certificates: []tls.Certificate{cert},
        }
        listener, err = tls.Listen("tcp", b.opts.brokerAddr, tlsConfig)
        if err != nil {
            return WrapError(err, "TLS broker failed to listen on "+b.opts.brokerAddr, "BROKER_TLS_LISTEN_ERROR")
        }
    } else {
        listener, err = net.Listen("tcp", b.opts.brokerAddr)
        if err != nil {
            return WrapError(err, "broker failed to listen on "+b.opts.brokerAddr, "BROKER_LISTEN_ERROR")
        }
    }
    b.listener = listener
    defer b.Close()
    const maxConcurrentConnections = 100
    sem := make(chan struct{}, maxConcurrentConnections)
    for {
        select {
        case <-ctx.Done():
            log.Printf("BROKER - Shutdown signal received")
            return ctx.Err()
        default:
            conn, err := listener.Accept()
            if err != nil {
                if atomic.LoadInt32(&b.isShutdown) == 1 {
                    return nil
                }
                log.Printf("BROKER - Error accepting connection: %v", err)
                continue
    // Start health checker
    b.healthChecker.Start()

    // Start connection cleanup routine
    b.wg.Add(1)
    go b.connectionCleanupRoutine()

    // Start metrics collection routine
    b.wg.Add(1)
    go b.metricsCollectionRoutine()

    // Start message store cleanup routine
    b.wg.Add(1)
    go b.messageStoreCleanupRoutine()

    // Start admin server if enabled
    if b.adminServer != nil {
        b.wg.Add(1)
        go func() {
            defer b.wg.Done()
            if err := b.adminServer.Start(); err != nil {
                b.logger.Error("Failed to start admin server", logger.Field{Key: "error", Value: err.Error()})
            }

            // Configure connection for broker-consumer communication with NO timeouts
            if tcpConn, ok := conn.(*net.TCPConn); ok {
                // Enable TCP keep-alive for all connections
                tcpConn.SetKeepAlive(true)
                tcpConn.SetKeepAlivePeriod(30 * time.Second)

                // NEVER set any deadlines for broker-consumer connections
                // These connections must remain open indefinitely for persistent communication
                // DO NOT call: tcpConn.SetReadDeadline() or tcpConn.SetWriteDeadline()

                log.Printf("BROKER - TCP keep-alive enabled for connection from %s (NO timeouts)", conn.RemoteAddr())
            }

            sem <- struct{}{}
            go func() {
                defer func() { <-sem }()
                defer conn.Close()
                b.handleConnection(ctx, conn)
            }()
        }
    }()
    }
}

// handleConnection handles a single connection with NO timeouts for persistent broker-consumer communication
func (b *Broker) handleConnection(ctx context.Context, conn net.Conn) {
    defer func() {
        if r := recover(); r != nil {
            b.logger.Error("Connection handler panic",
                logger.Field{Key: "panic", Value: fmt.Sprintf("%v", r)},
                logger.Field{Key: "remote_addr", Value: conn.RemoteAddr().String()})
        }
        conn.Close()
    }()

    // CRITICAL: Never set any timeouts on broker-consumer connections
    // These connections must remain open indefinitely for persistent communication

    for {
        select {
        case <-ctx.Done():
            b.logger.Debug("Context cancelled, closing connection",
                logger.Field{Key: "remote_addr", Value: conn.RemoteAddr().String()})
            return
        default:
            // Read message WITHOUT any timeout - this is crucial for persistent connections
            if err := b.readMessage(ctx, conn); err != nil {
                if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
                    b.logger.Debug("Connection closed by client",
                        logger.Field{Key: "remote_addr", Value: conn.RemoteAddr().String()})
                    return
                }
                // Don't return on timeout errors - they should not occur since we don't set timeouts
                if strings.Contains(err.Error(), "timeout") {
                    b.logger.Warn("Unexpected timeout on connection (should not happen)",
                        logger.Field{Key: "remote_addr", Value: conn.RemoteAddr().String()},
                        logger.Field{Key: "error", Value: err.Error()})
                    continue
                }
                b.logger.Error("Connection error",
                    logger.Field{Key: "remote_addr", Value: conn.RemoteAddr().String()},
                    logger.Field{Key: "error", Value: err.Error()})
                return
    // Start metrics server if enabled
    if b.metricsServer != nil {
        b.wg.Add(1)
        go func() {
            defer b.wg.Done()
            if err := b.metricsServer.Start(ctx); err != nil {
                b.logger.Error("Failed to start metrics server", logger.Field{Key: "error", Value: err.Error()})
            }
        }
    }()
    }

    b.logger.Info("Broker starting with production features enabled")

    // Start the broker with enhanced features
    if err := b.startBroker(ctx); err != nil {
        return err
    }

    // Wait for shutdown signal
    <-b.shutdown
    b.logger.Info("Broker shutting down")

    // Wait for all goroutines to finish
    b.wg.Wait()

    return nil
}

func (b *Broker) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
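The accept loop shown in this hunk bounds concurrency with a buffered channel used as a counting semaphore (maxConcurrentConnections = 100) and configures each accepted TCP connection with keep-alive but no deadlines. A self-contained sketch of that pattern, detached from the Broker type (handle is a placeholder for the broker's per-connection read loop):

package acceptloop

import (
    "context"
    "log"
    "net"
    "time"
)

// serve accepts connections and handles each in its own goroutine,
// allowing at most maxConcurrent handlers at a time via a channel semaphore.
func serve(ctx context.Context, listener net.Listener, maxConcurrent int) error {
    sem := make(chan struct{}, maxConcurrent)
    for {
        select {
        case <-ctx.Done():
            return ctx.Err()
        default:
            conn, err := listener.Accept()
            if err != nil {
                log.Printf("accept error: %v", err)
                continue
            }
            if tcp, ok := conn.(*net.TCPConn); ok {
                // Mirror the broker's policy: keep-alive on, no read/write deadlines.
                tcp.SetKeepAlive(true)
                tcp.SetKeepAlivePeriod(30 * time.Second)
            }
            sem <- struct{}{} // blocks while maxConcurrent handlers are already running
            go func(c net.Conn) {
                defer func() { <-sem }()
                defer c.Close()
                handle(ctx, c)
            }(conn)
        }
    }
}

// handle stands in for the broker's persistent, deadline-free message loop.
func handle(ctx context.Context, conn net.Conn) {
    <-ctx.Done()
}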
@@ -1602,63 +1572,9 @@ func (ims *InMemoryMessageStore) Cleanup(olderThan time.Time) error {
}

// Enhanced Start method with production features
func (b *Broker) StartEnhanced(ctx context.Context) error {
    // Start health checker
    b.healthChecker.Start()

    // Start connection cleanup routine
    b.wg.Add(1)
    go b.connectionCleanupRoutine()

    // Start metrics collection routine
    b.wg.Add(1)
    go b.metricsCollectionRoutine()

    // Start message store cleanup routine
    b.wg.Add(1)
    go b.messageStoreCleanupRoutine()

    // Start admin server if enabled
    if b.adminServer != nil {
        b.wg.Add(1)
        go func() {
            defer b.wg.Done()
            if err := b.adminServer.Start(); err != nil {
                b.logger.Error("Failed to start admin server", logger.Field{Key: "error", Value: err.Error()})
            }
        }()
    }

    // Start metrics server if enabled
    if b.metricsServer != nil {
        b.wg.Add(1)
        go func() {
            defer b.wg.Done()
            if err := b.metricsServer.Start(ctx); err != nil {
                b.logger.Error("Failed to start metrics server", logger.Field{Key: "error", Value: err.Error()})
            }
        }()
    }

    b.logger.Info("Enhanced broker starting with production features enabled")

    // Start the enhanced broker with its own implementation
    if err := b.startEnhancedBroker(ctx); err != nil {
        return err
    }

    // Wait for shutdown signal
    <-b.shutdown
    b.logger.Info("Enhanced broker shutting down")

    // Wait for all goroutines to finish
    b.wg.Wait()

    return nil
}

// startEnhancedBroker starts the core broker functionality
func (b *Broker) startEnhancedBroker(ctx context.Context) error {
// startBroker starts the core broker functionality
func (b *Broker) startBroker(ctx context.Context) error {
    addr := b.opts.BrokerAddr()
    listener, err := net.Listen("tcp", addr)
    if err != nil {
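Both the removed StartEnhanced and the new Start that absorbs it follow a common Go lifecycle pattern: background routines register on a sync.WaitGroup, the main goroutine blocks on a shutdown channel, then waits for the group to drain. A minimal sketch of that pattern, detached from the Broker type and with all names invented for illustration:

package lifecycle

import "sync"

type service struct {
    wg       sync.WaitGroup
    shutdown chan struct{}
}

func newService() *service {
    return &service{shutdown: make(chan struct{})}
}

// run starts each worker in its own goroutine, then blocks until stop is
// called and every worker has returned (mirrors <-b.shutdown / b.wg.Wait()).
func (s *service) run(workers ...func(stop <-chan struct{})) {
    for _, w := range workers {
        s.wg.Add(1)
        go func(work func(stop <-chan struct{})) {
            defer s.wg.Done()
            work(s.shutdown)
        }(w)
    }
    <-s.shutdown
    s.wg.Wait()
}

// stop signals shutdown exactly once; calling it twice would panic,
// so real code often wraps it in sync.Once.
func (s *service) stop() {
    close(s.shutdown)
}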
publisher.go (40 changed lines)
@@ -229,6 +229,46 @@ func (p *Publisher) Request(ctx context.Context, task Task, queue string) Result
    defer func() {
        _ = conn.Close()
    }()

    // Authenticate if security is enabled
    if p.opts.enableSecurity {
        if p.opts.username == "" || p.opts.password == "" {
            return Result{Error: fmt.Errorf("username and password required for authentication")}
        }

        authPayload := map[string]string{
            "username": p.opts.username,
            "password": p.opts.password,
        }
        payload, err := json.Marshal(authPayload)
        if err != nil {
            return Result{Error: err}
        }

        headers := map[string]string{
            consts.PublisherKey: p.id,
            consts.ContentType:  consts.TypeJson,
        }
        msg, err := codec.NewMessage(consts.AUTH, payload, "", headers)
        if err != nil {
            return Result{Error: err}
        }

        err = codec.SendMessage(ctx, conn, msg)
        if err != nil {
            return Result{Error: err}
        }

        // Wait for AUTH_ACK
        resp, err := codec.ReadMessage(ctx, conn)
        if err != nil {
            return Result{Error: fmt.Errorf("authentication failed: %w", err)}
        }
        if resp.Command != consts.AUTH_ACK {
            return Result{Error: fmt.Errorf("authentication failed: %s", string(resp.Payload))}
        }
    }

    err = p.send(ctx, queue, task, conn, consts.PUBLISH)
    resultCh := make(chan Result)
    go func() {
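The hunk ends just as Request spins up a goroutine that feeds resultCh. The rest of the function sits outside the visible diff, so the following is only an assumption about its usual shape: read the response in the background, then select between the delivered result and context cancellation.

package requestsketch

import "context"

// Result is a reduced stand-in for the repository's Result type.
type Result struct {
    Payload []byte
    Error   error
}

// awaitResult waits for a single result from read, or gives up when ctx ends.
// read is a placeholder for the blocking response read that Request performs.
func awaitResult(ctx context.Context, read func() Result) Result {
    resultCh := make(chan Result, 1)
    go func() {
        resultCh <- read()
    }()

    select {
    case <-ctx.Done():
        return Result{Error: ctx.Err()}
    case res := <-resultCh:
        return res
    }
}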