This commit is contained in:
sujit
2025-09-24 16:19:36 +05:45
parent ca422e5fda
commit 43f315e07c
5 changed files with 119 additions and 162 deletions

2
ctx.go
View File

@@ -148,7 +148,7 @@ func GetConnection(addr string, config TLSConfig) (net.Conn, error) {
return nil, err
}
// Store the new connection in the pool.
connPool.Store(key, conn)
// connPool.Store(key, conn) // Disable pooling for now
return conn, nil
}

View File

@@ -20,10 +20,11 @@ func main() {
)
for i := 0; i < 2; i++ {
// publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"))
err := publisher.Publish(context.Background(), task, "queue1")
if err != nil {
panic(err)
result := publisher.Request(context.Background(), task, "queue1")
if result.Error != nil {
panic(result.Error)
}
fmt.Println(string(result.Payload))
}
fmt.Println("Async task published successfully")
}

View File

@@ -24,7 +24,7 @@ func main() {
}
b.NewQueue("queue1")
b.NewQueue("queue2")
b.StartEnhanced(context.Background())
b.Start(context.Background())
}
// InitializeDefaults adds default permissions, roles, and users for development/testing

230
mq.go
View File

@@ -2,7 +2,6 @@ package mq
import (
"context"
"crypto/tls"
"fmt"
"log"
"net"
@@ -481,7 +480,8 @@ type Broker struct {
securityManager *SecurityManager
adminServer *AdminServer
metricsServer *MetricsServer
authenticatedConns storage.IMap[string, bool] // authenticated connections
authenticatedConns storage.IMap[string, bool] // authenticated connections
taskHeaders storage.IMap[string, map[string]string] // task headers by task ID
isShutdown int32
shutdown chan struct{}
wg sync.WaitGroup
@@ -507,6 +507,7 @@ func NewBroker(opts ...Option) *Broker {
metricsCollector: NewMetricsCollector(),
messageStore: NewInMemoryMessageStore(),
authenticatedConns: memory.New[string, bool](),
taskHeaders: memory.New[string, map[string]string](),
shutdown: make(chan struct{}),
logger: options.Logger(),
}
@@ -623,6 +624,9 @@ func (b *Broker) isAuthenticated(connID string) bool {
}
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
// Set message headers in context for publisher/consumer ID extraction
ctx = SetHeaders(ctx, msg.Headers)
connID := conn.RemoteAddr().String()
// Check authentication for protected commands
@@ -783,6 +787,21 @@ func (b *Broker) OnConsumerResume(ctx context.Context, _ *codec.Message) {
}
func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) {
// Extract task ID from response
var result Result
if err := json.Unmarshal(msg.Payload, &result); err != nil {
log.Printf("Error unmarshaling response: %v", err)
return
}
taskID := result.TaskID
// Retrieve stored headers for this task
if headers, ok := b.taskHeaders.Get(taskID); ok {
ctx = SetHeaders(ctx, headers)
// Clean up stored headers
b.taskHeaders.Del(taskID)
}
msg.Command = consts.RESPONSE
b.HandleCallback(ctx, msg)
awaitResponse, ok := GetAwaitResponse(ctx)
@@ -822,6 +841,9 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M
taskID, _ := jsonparser.GetString(msg.Payload, "id")
log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID)
// Store headers for response routing
b.taskHeaders.Set(taskID, msg.Headers)
ack, err := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers)
if err != nil {
log.Printf("Error creating PUBLISH_ACK message: %v\n", err)
@@ -861,110 +883,58 @@ func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec
}
func (b *Broker) Start(ctx context.Context) error {
var listener net.Listener
var err error
if b.opts.tlsConfig.UseTLS {
cert, err := tls.LoadX509KeyPair(b.opts.tlsConfig.CertPath, b.opts.tlsConfig.KeyPath)
if err != nil {
return WrapError(err, "failed to load TLS certificates for broker", "BROKER_TLS_CERT_ERROR")
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
}
listener, err = tls.Listen("tcp", b.opts.brokerAddr, tlsConfig)
if err != nil {
return WrapError(err, "TLS broker failed to listen on "+b.opts.brokerAddr, "BROKER_TLS_LISTEN_ERROR")
}
} else {
listener, err = net.Listen("tcp", b.opts.brokerAddr)
if err != nil {
return WrapError(err, "broker failed to listen on "+b.opts.brokerAddr, "BROKER_LISTEN_ERROR")
}
}
b.listener = listener
defer b.Close()
const maxConcurrentConnections = 100
sem := make(chan struct{}, maxConcurrentConnections)
for {
select {
case <-ctx.Done():
log.Printf("BROKER - Shutdown signal received")
return ctx.Err()
default:
conn, err := listener.Accept()
if err != nil {
if atomic.LoadInt32(&b.isShutdown) == 1 {
return nil
}
log.Printf("BROKER - Error accepting connection: %v", err)
continue
// Start health checker
b.healthChecker.Start()
// Start connection cleanup routine
b.wg.Add(1)
go b.connectionCleanupRoutine()
// Start metrics collection routine
b.wg.Add(1)
go b.metricsCollectionRoutine()
// Start message store cleanup routine
b.wg.Add(1)
go b.messageStoreCleanupRoutine()
// Start admin server if enabled
if b.adminServer != nil {
b.wg.Add(1)
go func() {
defer b.wg.Done()
if err := b.adminServer.Start(); err != nil {
b.logger.Error("Failed to start admin server", logger.Field{Key: "error", Value: err.Error()})
}
// Configure connection for broker-consumer communication with NO timeouts
if tcpConn, ok := conn.(*net.TCPConn); ok {
// Enable TCP keep-alive for all connections
tcpConn.SetKeepAlive(true)
tcpConn.SetKeepAlivePeriod(30 * time.Second)
// NEVER set any deadlines for broker-consumer connections
// These connections must remain open indefinitely for persistent communication
// DO NOT call: tcpConn.SetReadDeadline() or tcpConn.SetWriteDeadline()
log.Printf("BROKER - TCP keep-alive enabled for connection from %s (NO timeouts)", conn.RemoteAddr())
}
sem <- struct{}{}
go func() {
defer func() { <-sem }()
defer conn.Close()
b.handleConnection(ctx, conn)
}()
}
}()
}
}
// handleConnection handles a single connection with NO timeouts for persistent broker-consumer communication
func (b *Broker) handleConnection(ctx context.Context, conn net.Conn) {
defer func() {
if r := recover(); r != nil {
b.logger.Error("Connection handler panic",
logger.Field{Key: "panic", Value: fmt.Sprintf("%v", r)},
logger.Field{Key: "remote_addr", Value: conn.RemoteAddr().String()})
}
conn.Close()
}()
// CRITICAL: Never set any timeouts on broker-consumer connections
// These connections must remain open indefinitely for persistent communication
for {
select {
case <-ctx.Done():
b.logger.Debug("Context cancelled, closing connection",
logger.Field{Key: "remote_addr", Value: conn.RemoteAddr().String()})
return
default:
// Read message WITHOUT any timeout - this is crucial for persistent connections
if err := b.readMessage(ctx, conn); err != nil {
if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") {
b.logger.Debug("Connection closed by client",
logger.Field{Key: "remote_addr", Value: conn.RemoteAddr().String()})
return
}
// Don't return on timeout errors - they should not occur since we don't set timeouts
if strings.Contains(err.Error(), "timeout") {
b.logger.Warn("Unexpected timeout on connection (should not happen)",
logger.Field{Key: "remote_addr", Value: conn.RemoteAddr().String()},
logger.Field{Key: "error", Value: err.Error()})
continue
}
b.logger.Error("Connection error",
logger.Field{Key: "remote_addr", Value: conn.RemoteAddr().String()},
logger.Field{Key: "error", Value: err.Error()})
return
// Start metrics server if enabled
if b.metricsServer != nil {
b.wg.Add(1)
go func() {
defer b.wg.Done()
if err := b.metricsServer.Start(ctx); err != nil {
b.logger.Error("Failed to start metrics server", logger.Field{Key: "error", Value: err.Error()})
}
}
}()
}
b.logger.Info("Broker starting with production features enabled")
// Start the broker with enhanced features
if err := b.startBroker(ctx); err != nil {
return err
}
// Wait for shutdown signal
<-b.shutdown
b.logger.Info("Broker shutting down")
// Wait for all goroutines to finish
b.wg.Wait()
return nil
}
func (b *Broker) send(ctx context.Context, conn net.Conn, msg *codec.Message) error {
@@ -1602,63 +1572,9 @@ func (ims *InMemoryMessageStore) Cleanup(olderThan time.Time) error {
}
// Enhanced Start method with production features
func (b *Broker) StartEnhanced(ctx context.Context) error {
// Start health checker
b.healthChecker.Start()
// Start connection cleanup routine
b.wg.Add(1)
go b.connectionCleanupRoutine()
// Start metrics collection routine
b.wg.Add(1)
go b.metricsCollectionRoutine()
// Start message store cleanup routine
b.wg.Add(1)
go b.messageStoreCleanupRoutine()
// Start admin server if enabled
if b.adminServer != nil {
b.wg.Add(1)
go func() {
defer b.wg.Done()
if err := b.adminServer.Start(); err != nil {
b.logger.Error("Failed to start admin server", logger.Field{Key: "error", Value: err.Error()})
}
}()
}
// Start metrics server if enabled
if b.metricsServer != nil {
b.wg.Add(1)
go func() {
defer b.wg.Done()
if err := b.metricsServer.Start(ctx); err != nil {
b.logger.Error("Failed to start metrics server", logger.Field{Key: "error", Value: err.Error()})
}
}()
}
b.logger.Info("Enhanced broker starting with production features enabled")
// Start the enhanced broker with its own implementation
if err := b.startEnhancedBroker(ctx); err != nil {
return err
}
// Wait for shutdown signal
<-b.shutdown
b.logger.Info("Enhanced broker shutting down")
// Wait for all goroutines to finish
b.wg.Wait()
return nil
}
// startEnhancedBroker starts the core broker functionality
func (b *Broker) startEnhancedBroker(ctx context.Context) error {
// startBroker starts the core broker functionality
func (b *Broker) startBroker(ctx context.Context) error {
addr := b.opts.BrokerAddr()
listener, err := net.Listen("tcp", addr)
if err != nil {

View File

@@ -229,6 +229,46 @@ func (p *Publisher) Request(ctx context.Context, task Task, queue string) Result
defer func() {
_ = conn.Close()
}()
// Authenticate if security is enabled
if p.opts.enableSecurity {
if p.opts.username == "" || p.opts.password == "" {
return Result{Error: fmt.Errorf("username and password required for authentication")}
}
authPayload := map[string]string{
"username": p.opts.username,
"password": p.opts.password,
}
payload, err := json.Marshal(authPayload)
if err != nil {
return Result{Error: err}
}
headers := map[string]string{
consts.PublisherKey: p.id,
consts.ContentType: consts.TypeJson,
}
msg, err := codec.NewMessage(consts.AUTH, payload, "", headers)
if err != nil {
return Result{Error: err}
}
err = codec.SendMessage(ctx, conn, msg)
if err != nil {
return Result{Error: err}
}
// Wait for AUTH_ACK
resp, err := codec.ReadMessage(ctx, conn)
if err != nil {
return Result{Error: fmt.Errorf("authentication failed: %w", err)}
}
if resp.Command != consts.AUTH_ACK {
return Result{Error: fmt.Errorf("authentication failed: %s", string(resp.Payload))}
}
}
err = p.send(ctx, queue, task, conn, consts.PUBLISH)
resultCh := make(chan Result)
go func() {