This commit is contained in:
sujit
2025-09-24 16:01:07 +05:45
parent cb2869c98b
commit ca422e5fda
10 changed files with 826 additions and 708 deletions

187
mq.go
View File

@@ -280,6 +280,12 @@ type Options struct {
BrokerRateLimiter *RateLimiter // new field for broker rate limiting
ConsumerRateLimiter *RateLimiter // new field for consumer rate limiting
consumerTimeout time.Duration // timeout for consumer message processing (0 = no timeout)
adminAddr string // address for admin server
metricsAddr string // address for metrics server
enableSecurity bool // enable security features
enableMonitoring bool // enable monitoring features
username string // username for authentication
password string // password for authentication
}
func (o *Options) SetSyncMode(sync bool) {
@@ -467,15 +473,19 @@ type Broker struct {
listener net.Listener
// Enhanced production features
connectionPool *ConnectionPool
healthChecker *HealthChecker
circuitBreaker *EnhancedCircuitBreaker
metricsCollector *MetricsCollector
messageStore MessageStore
isShutdown int32
shutdown chan struct{}
wg sync.WaitGroup
logger logger.Logger
connectionPool *ConnectionPool
healthChecker *HealthChecker
circuitBreaker *EnhancedCircuitBreaker
metricsCollector *MetricsCollector
messageStore MessageStore
securityManager *SecurityManager
adminServer *AdminServer
metricsServer *MetricsServer
authenticatedConns storage.IMap[string, bool] // authenticated connections
isShutdown int32
shutdown chan struct{}
wg sync.WaitGroup
logger logger.Logger
}
func NewBroker(opts ...Option) *Broker {
@@ -491,13 +501,38 @@ func NewBroker(opts ...Option) *Broker {
opts: options,
// Enhanced production features
connectionPool: NewConnectionPool(1000), // max 1000 connections
healthChecker: NewHealthChecker(),
circuitBreaker: NewEnhancedCircuitBreaker(10, 30*time.Second), // 10 failures, 30s timeout
metricsCollector: NewMetricsCollector(),
messageStore: NewInMemoryMessageStore(),
shutdown: make(chan struct{}),
logger: options.Logger(),
connectionPool: NewConnectionPool(1000), // max 1000 connections
healthChecker: NewHealthChecker(),
circuitBreaker: NewEnhancedCircuitBreaker(10, 30*time.Second), // 10 failures, 30s timeout
metricsCollector: NewMetricsCollector(),
messageStore: NewInMemoryMessageStore(),
authenticatedConns: memory.New[string, bool](),
shutdown: make(chan struct{}),
logger: options.Logger(),
}
if options.enableSecurity {
broker.securityManager = NewSecurityManager()
}
if options.enableMonitoring {
if options.adminAddr != "" {
broker.adminServer = NewAdminServer(broker, options.adminAddr, options.Logger())
}
if options.metricsAddr != "" {
// Need to create MonitoringConfig, use default
config := &MonitoringConfig{
EnableMetrics: true,
MetricsPort: 9090, // default
MetricsPath: "/metrics",
EnableHealthCheck: true,
HealthCheckPort: 8080,
HealthCheckPath: "/health",
HealthCheckInterval: time.Minute,
EnableLogging: true,
LogLevel: "info",
}
broker.metricsServer = NewMetricsServer(broker, config, options.Logger())
}
}
broker.healthChecker.broker = broker
@@ -508,6 +543,18 @@ func (b *Broker) Options() *Options {
return b.opts
}
func (b *Broker) SecurityManager() *SecurityManager {
return b.securityManager
}
// InitializeSecurity initializes default users, roles, and permissions for development/testing
func (b *Broker) InitializeSecurity() error {
if b.securityManager == nil {
return fmt.Errorf("security manager not initialized")
}
return nil
}
func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
consumerID, ok := GetConsumerID(ctx)
if ok && consumerID != "" {
@@ -553,6 +600,10 @@ func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
b.publishers.Del(publisherID)
}
}
// Remove from authenticated connections
connID := conn.RemoteAddr().String()
b.authenticatedConns.Del(connID)
log.Printf("BROKER - Connection closed: address %s", conn.RemoteAddr())
return nil
}
@@ -563,8 +614,29 @@ func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) {
}
}
func (b *Broker) isAuthenticated(connID string) bool {
if b.securityManager == nil {
return true // no security, allow all
}
_, ok := b.authenticatedConns.Get(connID)
return ok
}
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
connID := conn.RemoteAddr().String()
// Check authentication for protected commands
if b.securityManager != nil && (msg.Command == consts.PUBLISH || msg.Command == consts.SUBSCRIBE) {
if !b.isAuthenticated(connID) {
b.logger.Warn("Unauthenticated access attempt", logger.Field{Key: "command", Value: msg.Command.String()}, logger.Field{Key: "conn", Value: connID})
// Send error response
return
}
}
switch msg.Command {
case consts.AUTH:
b.AuthHandler(ctx, conn, msg)
case consts.PUBLISH:
b.PublishHandler(ctx, conn, msg)
case consts.SUBSCRIBE:
@@ -588,6 +660,54 @@ func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Con
}
}
func (b *Broker) AuthHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
connID := conn.RemoteAddr().String()
// Parse auth credentials from payload
var authReq map[string]any
if err := json.Unmarshal(msg.Payload, &authReq); err != nil {
b.logger.Error("Invalid auth request", logger.Field{Key: "error", Value: err.Error()})
return
}
username, _ := authReq["username"].(string)
password, _ := authReq["password"].(string)
// Authenticate
user, err := b.securityManager.Authenticate(ctx, map[string]any{
"username": username,
"password": password,
})
if err != nil {
b.logger.Warn("Authentication failed", logger.Field{Key: "username", Value: username}, logger.Field{Key: "conn", Value: connID})
// Send AUTH_DENY
denyMsg, err := codec.NewMessage(consts.AUTH_DENY, []byte(fmt.Sprintf(`{"error":"%s"}`, err.Error())), "", msg.Headers)
if err != nil {
b.logger.Error("Failed to create AUTH_DENY message", logger.Field{Key: "error", Value: err.Error()})
return
}
if err := b.send(ctx, conn, denyMsg); err != nil {
b.logger.Error("Failed to send AUTH_DENY", logger.Field{Key: "error", Value: err.Error()})
}
return
}
// Mark as authenticated
b.authenticatedConns.Set(connID, true)
// Send AUTH_ACK
ackMsg, err := codec.NewMessage(consts.AUTH_ACK, []byte(`{"status":"authenticated"}`), "", msg.Headers)
if err != nil {
b.logger.Error("Failed to create AUTH_ACK message", logger.Field{Key: "error", Value: err.Error()})
return
}
if err := b.send(ctx, conn, ackMsg); err != nil {
b.logger.Error("Failed to send AUTH_ACK", logger.Field{Key: "error", Value: err.Error()})
}
b.logger.Info("User authenticated", logger.Field{Key: "username", Value: user.Username}, logger.Field{Key: "conn", Value: connID})
}
func (b *Broker) AdjustConsumerWorkers(noOfWorkers int, consumerID ...string) {
b.consumers.ForEach(func(_ string, c *consumer) bool {
return true
@@ -1498,10 +1618,43 @@ func (b *Broker) StartEnhanced(ctx context.Context) error {
b.wg.Add(1)
go b.messageStoreCleanupRoutine()
// Start admin server if enabled
if b.adminServer != nil {
b.wg.Add(1)
go func() {
defer b.wg.Done()
if err := b.adminServer.Start(); err != nil {
b.logger.Error("Failed to start admin server", logger.Field{Key: "error", Value: err.Error()})
}
}()
}
// Start metrics server if enabled
if b.metricsServer != nil {
b.wg.Add(1)
go func() {
defer b.wg.Done()
if err := b.metricsServer.Start(ctx); err != nil {
b.logger.Error("Failed to start metrics server", logger.Field{Key: "error", Value: err.Error()})
}
}()
}
b.logger.Info("Enhanced broker starting with production features enabled")
// Start the enhanced broker with its own implementation
return b.startEnhancedBroker(ctx)
if err := b.startEnhancedBroker(ctx); err != nil {
return err
}
// Wait for shutdown signal
<-b.shutdown
b.logger.Info("Enhanced broker shutting down")
// Wait for all goroutines to finish
b.wg.Wait()
return nil
}
// startEnhancedBroker starts the core broker functionality