diff --git a/ctx.go b/ctx.go index 2ca0b9c..04a2e8c 100644 --- a/ctx.go +++ b/ctx.go @@ -148,7 +148,7 @@ func GetConnection(addr string, config TLSConfig) (net.Conn, error) { return nil, err } // Store the new connection in the pool. - connPool.Store(key, conn) + // connPool.Store(key, conn) // Disable pooling for now return conn, nil } diff --git a/examples/publisher.go b/examples/publisher.go index e2841b0..5a438bf 100644 --- a/examples/publisher.go +++ b/examples/publisher.go @@ -20,10 +20,11 @@ func main() { ) for i := 0; i < 2; i++ { // publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key")) - err := publisher.Publish(context.Background(), task, "queue1") - if err != nil { - panic(err) + result := publisher.Request(context.Background(), task, "queue1") + if result.Error != nil { + panic(result.Error) } + fmt.Println(string(result.Payload)) } fmt.Println("Async task published successfully") } diff --git a/examples/server.go b/examples/server.go index 6496658..9cd7778 100644 --- a/examples/server.go +++ b/examples/server.go @@ -24,7 +24,7 @@ func main() { } b.NewQueue("queue1") b.NewQueue("queue2") - b.StartEnhanced(context.Background()) + b.Start(context.Background()) } // InitializeDefaults adds default permissions, roles, and users for development/testing diff --git a/mq.go b/mq.go index c0ebd04..b3b8030 100644 --- a/mq.go +++ b/mq.go @@ -2,7 +2,6 @@ package mq import ( "context" - "crypto/tls" "fmt" "log" "net" @@ -481,7 +480,8 @@ type Broker struct { securityManager *SecurityManager adminServer *AdminServer metricsServer *MetricsServer - authenticatedConns storage.IMap[string, bool] // authenticated connections + authenticatedConns storage.IMap[string, bool] // authenticated connections + taskHeaders storage.IMap[string, map[string]string] // task headers by task ID isShutdown int32 shutdown chan struct{} wg sync.WaitGroup @@ -507,6 +507,7 @@ func NewBroker(opts ...Option) *Broker { metricsCollector: NewMetricsCollector(), messageStore: NewInMemoryMessageStore(), authenticatedConns: memory.New[string, bool](), + taskHeaders: memory.New[string, map[string]string](), shutdown: make(chan struct{}), logger: options.Logger(), } @@ -623,6 +624,9 @@ func (b *Broker) isAuthenticated(connID string) bool { } func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) { + // Set message headers in context for publisher/consumer ID extraction + ctx = SetHeaders(ctx, msg.Headers) + connID := conn.RemoteAddr().String() // Check authentication for protected commands @@ -783,6 +787,21 @@ func (b *Broker) OnConsumerResume(ctx context.Context, _ *codec.Message) { } func (b *Broker) MessageResponseHandler(ctx context.Context, msg *codec.Message) { + // Extract task ID from response + var result Result + if err := json.Unmarshal(msg.Payload, &result); err != nil { + log.Printf("Error unmarshaling response: %v", err) + return + } + taskID := result.TaskID + + // Retrieve stored headers for this task + if headers, ok := b.taskHeaders.Get(taskID); ok { + ctx = SetHeaders(ctx, headers) + // Clean up stored headers + b.taskHeaders.Del(taskID) + } + msg.Command = consts.RESPONSE b.HandleCallback(ctx, msg) awaitResponse, ok := GetAwaitResponse(ctx) @@ -822,6 +841,9 @@ func (b *Broker) PublishHandler(ctx context.Context, conn net.Conn, msg *codec.M taskID, _ := jsonparser.GetString(msg.Payload, "id") log.Printf("BROKER - PUBLISH ~> received from %s on %s for Task %s", pub.id, msg.Queue, taskID) + // Store headers for response routing + b.taskHeaders.Set(taskID, msg.Headers) + ack, err := codec.NewMessage(consts.PUBLISH_ACK, utils.ToByte(fmt.Sprintf(`{"id":"%s"}`, taskID)), msg.Queue, msg.Headers) if err != nil { log.Printf("Error creating PUBLISH_ACK message: %v\n", err) @@ -861,110 +883,58 @@ func (b *Broker) SubscribeHandler(ctx context.Context, conn net.Conn, msg *codec } func (b *Broker) Start(ctx context.Context) error { - var listener net.Listener - var err error - if b.opts.tlsConfig.UseTLS { - cert, err := tls.LoadX509KeyPair(b.opts.tlsConfig.CertPath, b.opts.tlsConfig.KeyPath) - if err != nil { - return WrapError(err, "failed to load TLS certificates for broker", "BROKER_TLS_CERT_ERROR") - } - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - } - listener, err = tls.Listen("tcp", b.opts.brokerAddr, tlsConfig) - if err != nil { - return WrapError(err, "TLS broker failed to listen on "+b.opts.brokerAddr, "BROKER_TLS_LISTEN_ERROR") - } - } else { - listener, err = net.Listen("tcp", b.opts.brokerAddr) - if err != nil { - return WrapError(err, "broker failed to listen on "+b.opts.brokerAddr, "BROKER_LISTEN_ERROR") - } - } - b.listener = listener - defer b.Close() - const maxConcurrentConnections = 100 - sem := make(chan struct{}, maxConcurrentConnections) - for { - select { - case <-ctx.Done(): - log.Printf("BROKER - Shutdown signal received") - return ctx.Err() - default: - conn, err := listener.Accept() - if err != nil { - if atomic.LoadInt32(&b.isShutdown) == 1 { - return nil - } - log.Printf("BROKER - Error accepting connection: %v", err) - continue + // Start health checker + b.healthChecker.Start() + + // Start connection cleanup routine + b.wg.Add(1) + go b.connectionCleanupRoutine() + + // Start metrics collection routine + b.wg.Add(1) + go b.metricsCollectionRoutine() + + // Start message store cleanup routine + b.wg.Add(1) + go b.messageStoreCleanupRoutine() + + // Start admin server if enabled + if b.adminServer != nil { + b.wg.Add(1) + go func() { + defer b.wg.Done() + if err := b.adminServer.Start(); err != nil { + b.logger.Error("Failed to start admin server", logger.Field{Key: "error", Value: err.Error()}) } - - // Configure connection for broker-consumer communication with NO timeouts - if tcpConn, ok := conn.(*net.TCPConn); ok { - // Enable TCP keep-alive for all connections - tcpConn.SetKeepAlive(true) - tcpConn.SetKeepAlivePeriod(30 * time.Second) - - // NEVER set any deadlines for broker-consumer connections - // These connections must remain open indefinitely for persistent communication - // DO NOT call: tcpConn.SetReadDeadline() or tcpConn.SetWriteDeadline() - - log.Printf("BROKER - TCP keep-alive enabled for connection from %s (NO timeouts)", conn.RemoteAddr()) - } - - sem <- struct{}{} - go func() { - defer func() { <-sem }() - defer conn.Close() - b.handleConnection(ctx, conn) - }() - } + }() } -} -// handleConnection handles a single connection with NO timeouts for persistent broker-consumer communication -func (b *Broker) handleConnection(ctx context.Context, conn net.Conn) { - defer func() { - if r := recover(); r != nil { - b.logger.Error("Connection handler panic", - logger.Field{Key: "panic", Value: fmt.Sprintf("%v", r)}, - logger.Field{Key: "remote_addr", Value: conn.RemoteAddr().String()}) - } - conn.Close() - }() - - // CRITICAL: Never set any timeouts on broker-consumer connections - // These connections must remain open indefinitely for persistent communication - - for { - select { - case <-ctx.Done(): - b.logger.Debug("Context cancelled, closing connection", - logger.Field{Key: "remote_addr", Value: conn.RemoteAddr().String()}) - return - default: - // Read message WITHOUT any timeout - this is crucial for persistent connections - if err := b.readMessage(ctx, conn); err != nil { - if err.Error() == "EOF" || strings.Contains(err.Error(), "closed network connection") { - b.logger.Debug("Connection closed by client", - logger.Field{Key: "remote_addr", Value: conn.RemoteAddr().String()}) - return - } - // Don't return on timeout errors - they should not occur since we don't set timeouts - if strings.Contains(err.Error(), "timeout") { - b.logger.Warn("Unexpected timeout on connection (should not happen)", - logger.Field{Key: "remote_addr", Value: conn.RemoteAddr().String()}, - logger.Field{Key: "error", Value: err.Error()}) - continue - } - b.logger.Error("Connection error", - logger.Field{Key: "remote_addr", Value: conn.RemoteAddr().String()}, - logger.Field{Key: "error", Value: err.Error()}) - return + // Start metrics server if enabled + if b.metricsServer != nil { + b.wg.Add(1) + go func() { + defer b.wg.Done() + if err := b.metricsServer.Start(ctx); err != nil { + b.logger.Error("Failed to start metrics server", logger.Field{Key: "error", Value: err.Error()}) } - } + }() } + + b.logger.Info("Broker starting with production features enabled") + + // Start the broker with enhanced features + if err := b.startBroker(ctx); err != nil { + return err + } + + // Wait for shutdown signal + <-b.shutdown + b.logger.Info("Broker shutting down") + + // Wait for all goroutines to finish + b.wg.Wait() + + return nil } func (b *Broker) send(ctx context.Context, conn net.Conn, msg *codec.Message) error { @@ -1602,63 +1572,9 @@ func (ims *InMemoryMessageStore) Cleanup(olderThan time.Time) error { } // Enhanced Start method with production features -func (b *Broker) StartEnhanced(ctx context.Context) error { - // Start health checker - b.healthChecker.Start() - // Start connection cleanup routine - b.wg.Add(1) - go b.connectionCleanupRoutine() - - // Start metrics collection routine - b.wg.Add(1) - go b.metricsCollectionRoutine() - - // Start message store cleanup routine - b.wg.Add(1) - go b.messageStoreCleanupRoutine() - - // Start admin server if enabled - if b.adminServer != nil { - b.wg.Add(1) - go func() { - defer b.wg.Done() - if err := b.adminServer.Start(); err != nil { - b.logger.Error("Failed to start admin server", logger.Field{Key: "error", Value: err.Error()}) - } - }() - } - - // Start metrics server if enabled - if b.metricsServer != nil { - b.wg.Add(1) - go func() { - defer b.wg.Done() - if err := b.metricsServer.Start(ctx); err != nil { - b.logger.Error("Failed to start metrics server", logger.Field{Key: "error", Value: err.Error()}) - } - }() - } - - b.logger.Info("Enhanced broker starting with production features enabled") - - // Start the enhanced broker with its own implementation - if err := b.startEnhancedBroker(ctx); err != nil { - return err - } - - // Wait for shutdown signal - <-b.shutdown - b.logger.Info("Enhanced broker shutting down") - - // Wait for all goroutines to finish - b.wg.Wait() - - return nil -} - -// startEnhancedBroker starts the core broker functionality -func (b *Broker) startEnhancedBroker(ctx context.Context) error { +// startBroker starts the core broker functionality +func (b *Broker) startBroker(ctx context.Context) error { addr := b.opts.BrokerAddr() listener, err := net.Listen("tcp", addr) if err != nil { diff --git a/publisher.go b/publisher.go index 77f51ac..36fd5d1 100644 --- a/publisher.go +++ b/publisher.go @@ -229,6 +229,46 @@ func (p *Publisher) Request(ctx context.Context, task Task, queue string) Result defer func() { _ = conn.Close() }() + + // Authenticate if security is enabled + if p.opts.enableSecurity { + if p.opts.username == "" || p.opts.password == "" { + return Result{Error: fmt.Errorf("username and password required for authentication")} + } + + authPayload := map[string]string{ + "username": p.opts.username, + "password": p.opts.password, + } + payload, err := json.Marshal(authPayload) + if err != nil { + return Result{Error: err} + } + + headers := map[string]string{ + consts.PublisherKey: p.id, + consts.ContentType: consts.TypeJson, + } + msg, err := codec.NewMessage(consts.AUTH, payload, "", headers) + if err != nil { + return Result{Error: err} + } + + err = codec.SendMessage(ctx, conn, msg) + if err != nil { + return Result{Error: err} + } + + // Wait for AUTH_ACK + resp, err := codec.ReadMessage(ctx, conn) + if err != nil { + return Result{Error: fmt.Errorf("authentication failed: %w", err)} + } + if resp.Command != consts.AUTH_ACK { + return Result{Error: fmt.Errorf("authentication failed: %s", string(resp.Payload))} + } + } + err = p.send(ctx, queue, task, conn, consts.PUBLISH) resultCh := make(chan Result) go func() {