mirror of
https://github.com/oarkflow/mq.git
synced 2025-09-27 04:15:52 +08:00
update
This commit is contained in:
187
mq.go
187
mq.go
@@ -280,6 +280,12 @@ type Options struct {
|
||||
BrokerRateLimiter *RateLimiter // new field for broker rate limiting
|
||||
ConsumerRateLimiter *RateLimiter // new field for consumer rate limiting
|
||||
consumerTimeout time.Duration // timeout for consumer message processing (0 = no timeout)
|
||||
adminAddr string // address for admin server
|
||||
metricsAddr string // address for metrics server
|
||||
enableSecurity bool // enable security features
|
||||
enableMonitoring bool // enable monitoring features
|
||||
username string // username for authentication
|
||||
password string // password for authentication
|
||||
}
|
||||
|
||||
func (o *Options) SetSyncMode(sync bool) {
|
||||
@@ -467,15 +473,19 @@ type Broker struct {
|
||||
listener net.Listener
|
||||
|
||||
// Enhanced production features
|
||||
connectionPool *ConnectionPool
|
||||
healthChecker *HealthChecker
|
||||
circuitBreaker *EnhancedCircuitBreaker
|
||||
metricsCollector *MetricsCollector
|
||||
messageStore MessageStore
|
||||
isShutdown int32
|
||||
shutdown chan struct{}
|
||||
wg sync.WaitGroup
|
||||
logger logger.Logger
|
||||
connectionPool *ConnectionPool
|
||||
healthChecker *HealthChecker
|
||||
circuitBreaker *EnhancedCircuitBreaker
|
||||
metricsCollector *MetricsCollector
|
||||
messageStore MessageStore
|
||||
securityManager *SecurityManager
|
||||
adminServer *AdminServer
|
||||
metricsServer *MetricsServer
|
||||
authenticatedConns storage.IMap[string, bool] // authenticated connections
|
||||
isShutdown int32
|
||||
shutdown chan struct{}
|
||||
wg sync.WaitGroup
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewBroker(opts ...Option) *Broker {
|
||||
@@ -491,13 +501,38 @@ func NewBroker(opts ...Option) *Broker {
|
||||
opts: options,
|
||||
|
||||
// Enhanced production features
|
||||
connectionPool: NewConnectionPool(1000), // max 1000 connections
|
||||
healthChecker: NewHealthChecker(),
|
||||
circuitBreaker: NewEnhancedCircuitBreaker(10, 30*time.Second), // 10 failures, 30s timeout
|
||||
metricsCollector: NewMetricsCollector(),
|
||||
messageStore: NewInMemoryMessageStore(),
|
||||
shutdown: make(chan struct{}),
|
||||
logger: options.Logger(),
|
||||
connectionPool: NewConnectionPool(1000), // max 1000 connections
|
||||
healthChecker: NewHealthChecker(),
|
||||
circuitBreaker: NewEnhancedCircuitBreaker(10, 30*time.Second), // 10 failures, 30s timeout
|
||||
metricsCollector: NewMetricsCollector(),
|
||||
messageStore: NewInMemoryMessageStore(),
|
||||
authenticatedConns: memory.New[string, bool](),
|
||||
shutdown: make(chan struct{}),
|
||||
logger: options.Logger(),
|
||||
}
|
||||
|
||||
if options.enableSecurity {
|
||||
broker.securityManager = NewSecurityManager()
|
||||
}
|
||||
if options.enableMonitoring {
|
||||
if options.adminAddr != "" {
|
||||
broker.adminServer = NewAdminServer(broker, options.adminAddr, options.Logger())
|
||||
}
|
||||
if options.metricsAddr != "" {
|
||||
// Need to create MonitoringConfig, use default
|
||||
config := &MonitoringConfig{
|
||||
EnableMetrics: true,
|
||||
MetricsPort: 9090, // default
|
||||
MetricsPath: "/metrics",
|
||||
EnableHealthCheck: true,
|
||||
HealthCheckPort: 8080,
|
||||
HealthCheckPath: "/health",
|
||||
HealthCheckInterval: time.Minute,
|
||||
EnableLogging: true,
|
||||
LogLevel: "info",
|
||||
}
|
||||
broker.metricsServer = NewMetricsServer(broker, config, options.Logger())
|
||||
}
|
||||
}
|
||||
|
||||
broker.healthChecker.broker = broker
|
||||
@@ -508,6 +543,18 @@ func (b *Broker) Options() *Options {
|
||||
return b.opts
|
||||
}
|
||||
|
||||
func (b *Broker) SecurityManager() *SecurityManager {
|
||||
return b.securityManager
|
||||
}
|
||||
|
||||
// InitializeSecurity initializes default users, roles, and permissions for development/testing
|
||||
func (b *Broker) InitializeSecurity() error {
|
||||
if b.securityManager == nil {
|
||||
return fmt.Errorf("security manager not initialized")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
|
||||
consumerID, ok := GetConsumerID(ctx)
|
||||
if ok && consumerID != "" {
|
||||
@@ -553,6 +600,10 @@ func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
|
||||
b.publishers.Del(publisherID)
|
||||
}
|
||||
}
|
||||
// Remove from authenticated connections
|
||||
connID := conn.RemoteAddr().String()
|
||||
b.authenticatedConns.Del(connID)
|
||||
|
||||
log.Printf("BROKER - Connection closed: address %s", conn.RemoteAddr())
|
||||
return nil
|
||||
}
|
||||
@@ -563,8 +614,29 @@ func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) isAuthenticated(connID string) bool {
|
||||
if b.securityManager == nil {
|
||||
return true // no security, allow all
|
||||
}
|
||||
_, ok := b.authenticatedConns.Get(connID)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||||
connID := conn.RemoteAddr().String()
|
||||
|
||||
// Check authentication for protected commands
|
||||
if b.securityManager != nil && (msg.Command == consts.PUBLISH || msg.Command == consts.SUBSCRIBE) {
|
||||
if !b.isAuthenticated(connID) {
|
||||
b.logger.Warn("Unauthenticated access attempt", logger.Field{Key: "command", Value: msg.Command.String()}, logger.Field{Key: "conn", Value: connID})
|
||||
// Send error response
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
switch msg.Command {
|
||||
case consts.AUTH:
|
||||
b.AuthHandler(ctx, conn, msg)
|
||||
case consts.PUBLISH:
|
||||
b.PublishHandler(ctx, conn, msg)
|
||||
case consts.SUBSCRIBE:
|
||||
@@ -588,6 +660,54 @@ func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Con
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) AuthHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
||||
connID := conn.RemoteAddr().String()
|
||||
|
||||
// Parse auth credentials from payload
|
||||
var authReq map[string]any
|
||||
if err := json.Unmarshal(msg.Payload, &authReq); err != nil {
|
||||
b.logger.Error("Invalid auth request", logger.Field{Key: "error", Value: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
username, _ := authReq["username"].(string)
|
||||
password, _ := authReq["password"].(string)
|
||||
|
||||
// Authenticate
|
||||
user, err := b.securityManager.Authenticate(ctx, map[string]any{
|
||||
"username": username,
|
||||
"password": password,
|
||||
})
|
||||
if err != nil {
|
||||
b.logger.Warn("Authentication failed", logger.Field{Key: "username", Value: username}, logger.Field{Key: "conn", Value: connID})
|
||||
// Send AUTH_DENY
|
||||
denyMsg, err := codec.NewMessage(consts.AUTH_DENY, []byte(fmt.Sprintf(`{"error":"%s"}`, err.Error())), "", msg.Headers)
|
||||
if err != nil {
|
||||
b.logger.Error("Failed to create AUTH_DENY message", logger.Field{Key: "error", Value: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := b.send(ctx, conn, denyMsg); err != nil {
|
||||
b.logger.Error("Failed to send AUTH_DENY", logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Mark as authenticated
|
||||
b.authenticatedConns.Set(connID, true)
|
||||
|
||||
// Send AUTH_ACK
|
||||
ackMsg, err := codec.NewMessage(consts.AUTH_ACK, []byte(`{"status":"authenticated"}`), "", msg.Headers)
|
||||
if err != nil {
|
||||
b.logger.Error("Failed to create AUTH_ACK message", logger.Field{Key: "error", Value: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := b.send(ctx, conn, ackMsg); err != nil {
|
||||
b.logger.Error("Failed to send AUTH_ACK", logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
|
||||
b.logger.Info("User authenticated", logger.Field{Key: "username", Value: user.Username}, logger.Field{Key: "conn", Value: connID})
|
||||
}
|
||||
|
||||
func (b *Broker) AdjustConsumerWorkers(noOfWorkers int, consumerID ...string) {
|
||||
b.consumers.ForEach(func(_ string, c *consumer) bool {
|
||||
return true
|
||||
@@ -1498,10 +1618,43 @@ func (b *Broker) StartEnhanced(ctx context.Context) error {
|
||||
b.wg.Add(1)
|
||||
go b.messageStoreCleanupRoutine()
|
||||
|
||||
// Start admin server if enabled
|
||||
if b.adminServer != nil {
|
||||
b.wg.Add(1)
|
||||
go func() {
|
||||
defer b.wg.Done()
|
||||
if err := b.adminServer.Start(); err != nil {
|
||||
b.logger.Error("Failed to start admin server", logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Start metrics server if enabled
|
||||
if b.metricsServer != nil {
|
||||
b.wg.Add(1)
|
||||
go func() {
|
||||
defer b.wg.Done()
|
||||
if err := b.metricsServer.Start(ctx); err != nil {
|
||||
b.logger.Error("Failed to start metrics server", logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
b.logger.Info("Enhanced broker starting with production features enabled")
|
||||
|
||||
// Start the enhanced broker with its own implementation
|
||||
return b.startEnhancedBroker(ctx)
|
||||
if err := b.startEnhancedBroker(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for shutdown signal
|
||||
<-b.shutdown
|
||||
b.logger.Info("Enhanced broker shutting down")
|
||||
|
||||
// Wait for all goroutines to finish
|
||||
b.wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// startEnhancedBroker starts the core broker functionality
|
||||
|
Reference in New Issue
Block a user