diff --git a/codec/codec.go b/codec/codec.go index 2a647da..5244f5f 100644 --- a/codec/codec.go +++ b/codec/codec.go @@ -445,7 +445,7 @@ func validateCommand(cmd consts.CMD) error { } func validateQueue(queue string) error { - if len(queue) == 0 || len(queue) > MaxQueueLength { + if len(queue) > MaxQueueLength { return ErrInvalidQueue } return nil diff --git a/consts/constants.go b/consts/constants.go index 401fc6f..1d223f7 100644 --- a/consts/constants.go +++ b/consts/constants.go @@ -6,6 +6,9 @@ func (c CMD) IsValid() bool { return c >= PING && c <= CONSUMER_STOP } const ( PING CMD = iota + 1 + AUTH + AUTH_ACK + AUTH_DENY SUBSCRIBE SUBSCRIBE_ACK @@ -42,6 +45,12 @@ func (c CMD) String() string { switch c { case PING: return "PING" + case AUTH: + return "AUTH" + case AUTH_ACK: + return "AUTH_ACK" + case AUTH_DENY: + return "AUTH_DENY" case SUBSCRIBE: return "SUBSCRIBE" case SUBSCRIBE_ACK: diff --git a/consumer.go b/consumer.go index c20ffc2..a8b6c66 100644 --- a/consumer.go +++ b/consumer.go @@ -162,6 +162,40 @@ func (c *Consumer) subscribe(ctx context.Context, queue string) error { return c.waitForAck(ctx, c.conn) } +// Auth authenticates the consumer with the broker +func (c *Consumer) Auth(ctx context.Context, username, password string) error { + authPayload := map[string]string{ + "username": username, + "password": password, + } + payload, err := json.Marshal(authPayload) + if err != nil { + return err + } + + headers := HeadersWithConsumerID(ctx, c.id) + msg, err := codec.NewMessage(consts.AUTH, payload, "", headers) + if err != nil { + return fmt.Errorf("error creating auth message: %v", err) + } + + if err := c.send(ctx, c.conn, msg); err != nil { + return fmt.Errorf("error sending auth: %v", err) + } + + // Wait for AUTH_ACK + resp, err := c.receive(ctx, c.conn) + if err != nil { + return fmt.Errorf("error receiving auth response: %v", err) + } + + if resp.Command != consts.AUTH_ACK { + return fmt.Errorf("authentication failed: %s", string(resp.Payload)) + } + + return nil +} + func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error { fmt.Println("Consumer closed") return nil @@ -559,6 +593,16 @@ func (c *Consumer) Consume(ctx context.Context) error { return fmt.Errorf("initial connection failed: %w", err) } + // Authenticate if security is enabled + if c.opts.enableSecurity { + if c.opts.username == "" || c.opts.password == "" { + return fmt.Errorf("username and password required for authentication") + } + if err := c.Auth(ctx, c.opts.username, c.opts.password); err != nil { + return fmt.Errorf("authentication failed: %w", err) + } + } + // Initialize pool c.pool = NewPool( c.opts.numOfWorkers, diff --git a/examples/consumer.go b/examples/consumer.go new file mode 100644 index 0000000..fbb4eb1 --- /dev/null +++ b/examples/consumer.go @@ -0,0 +1,21 @@ +package main + +import ( + "context" + + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/examples/tasks" +) + +func main() { + n := &tasks.Node6{} + consumer1 := mq.NewConsumer("F", "queue1", n.ProcessTask, + mq.WithBrokerURL(":8081"), + mq.WithHTTPApi(true), + mq.WithWorkerPool(100, 4, 50000), + mq.WithSecurity(true), + mq.WithUsername("consumer"), + mq.WithPassword("con123"), + ) + consumer1.Consume(context.Background()) +} diff --git a/examples/publisher.go b/examples/publisher.go new file mode 100644 index 0000000..e2841b0 --- /dev/null +++ b/examples/publisher.go @@ -0,0 +1,29 @@ +package main + +import ( + "context" + "fmt" + + "github.com/oarkflow/mq" +) + +func main() { + payload := []byte(`{"phone": "+123456789", "email": "abc.xyz@gmail.com", "age": 12}`) + task := mq.Task{ + Payload: payload, + } + publisher := mq.NewPublisher("publish-1", + mq.WithBrokerURL(":8081"), + mq.WithSecurity(true), + mq.WithUsername("publisher"), + mq.WithPassword("pub123"), + ) + for i := 0; i < 2; i++ { + // publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key")) + err := publisher.Publish(context.Background(), task, "queue1") + if err != nil { + panic(err) + } + } + fmt.Println("Async task published successfully") +} diff --git a/examples/server.go b/examples/server.go index 2d78d9b..6496658 100644 --- a/examples/server.go +++ b/examples/server.go @@ -1,604 +1,57 @@ -// fast_http_router.go -// Ultra-high performance HTTP router in Go matching gofiber speed -// Key optimizations: -// - Zero allocations on hot path (no slice/map allocations per request) -// - Byte-based routing for maximum speed -// - Pre-allocated pools for everything -// - Minimal interface overhead -// - Direct memory operations where possible - package main import ( "context" - "encoding/json" - "fmt" - "io" - "log" - "net/http" - "strings" - "sync" - "sync/atomic" "time" + + "github.com/oarkflow/mq" + "github.com/oarkflow/mq/examples/tasks" ) -// ---------------------------- -// Public Interfaces (minimal overhead) -// ---------------------------- - -type HandlerFunc func(*Ctx) error - -type Engine interface { - http.Handler - Group(prefix string, m ...HandlerFunc) RouteGroup - Use(m ...HandlerFunc) - GET(path string, h HandlerFunc) - POST(path string, h HandlerFunc) - PUT(path string, h HandlerFunc) - DELETE(path string, h HandlerFunc) - Static(prefix, root string) - ListenAndServe(addr string) error - Shutdown(ctx context.Context) error +func main() { + b := mq.NewBroker( + mq.WithCallback(tasks.Callback), + mq.WithBrokerURL(":8081"), + mq.WithSecurity(true), + // mq.WithMonitoring(true), + // mq.WithAdminAddr(":8080"), + // mq.WithMetricsAddr(":9090"), + ) + InitializeDefaults(b.SecurityManager()) + // b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert")) + if err := b.InitializeSecurity(); err != nil { + panic(err) + } + b.NewQueue("queue1") + b.NewQueue("queue2") + b.StartEnhanced(context.Background()) } -type RouteGroup interface { - Use(m ...HandlerFunc) - GET(path string, h HandlerFunc) - POST(path string, h HandlerFunc) - PUT(path string, h HandlerFunc) - DELETE(path string, h HandlerFunc) -} - -// ---------------------------- -// Ultra-fast param extraction -// ---------------------------- - -type Param struct { - Key string - Value string -} - -// Pre-allocated param slices to avoid any allocations -var paramPool = sync.Pool{ - New: func() any { - return make([]Param, 0, 16) - }, -} - -// ---------------------------- -// Context with zero allocations -// ---------------------------- - -type Ctx struct { - W http.ResponseWriter - Req *http.Request - params []Param - index int8 - plen int8 - - // Embedded handler chain (no slice allocation) - handlers [16]HandlerFunc // fixed size, 99% of routes have < 16 handlers - hlen int8 - - status int - engine *engine -} - -var ctxPool = sync.Pool{ - New: func() any { - return &Ctx{} - }, -} - -func (c *Ctx) reset() { - c.W = nil - c.Req = nil - if c.params != nil { - paramPool.Put(c.params[:0]) - c.params = nil +// InitializeDefaults adds default permissions, roles, and users for development/testing +func InitializeDefaults(sm *mq.SecurityManager) error { + permissions := []*mq.Permission{ + {Name: "task.publish", Resource: "task", Action: "publish", Description: "Publish tasks to queues", CreatedAt: time.Now()}, + {Name: "task.consume", Resource: "task", Action: "consume", Description: "Consume tasks from queues", CreatedAt: time.Now()}, + {Name: "queue.manage", Resource: "queue", Action: "manage", Description: "Manage queues", CreatedAt: time.Now()}, + {Name: "admin.system", Resource: "system", Action: "admin", Description: "System administration", CreatedAt: time.Now()}, } - c.index = 0 - c.plen = 0 - c.hlen = 0 - c.status = 0 - c.engine = nil -} - -// Ultra-fast param lookup (linear search is faster than map for < 8 params) -func (c *Ctx) Param(key string) string { - for i := int8(0); i < c.plen; i++ { - if c.params[i].Key == key { - return c.params[i].Value - } - } - return "" -} - -func (c *Ctx) addParam(key, value string) { - if c.params == nil { - c.params = paramPool.Get().([]Param) - } - if c.plen < 16 { // max 16 params - c.params = append(c.params, Param{Key: key, Value: value}) - c.plen++ - } -} - -// Zero-allocation header operations -func (c *Ctx) Set(key, val string) { - if c.W != nil { - c.W.Header().Set(key, val) - } -} - -func (c *Ctx) Get(key string) string { - if c.Req != nil { - return c.Req.Header.Get(key) - } - return "" -} - -// Ultra-fast response methods -func (c *Ctx) SendString(s string) error { - if c.status != 0 { - c.W.WriteHeader(c.status) - } - _, err := io.WriteString(c.W, s) - return err -} - -func (c *Ctx) JSON(v any) error { - c.Set("Content-Type", "application/json") - if c.status != 0 { - c.W.WriteHeader(c.status) - } - return json.NewEncoder(c.W).Encode(v) -} - -func (c *Ctx) Status(code int) { c.status = code } - -func (c *Ctx) Next() error { - for c.index < c.hlen { - h := c.handlers[c.index] - c.index++ - if err := h(c); err != nil { - return err - } - } - return nil -} - -// ---------------------------- -// Ultra-fast byte-based router -// ---------------------------- - -type methodType uint8 - -const ( - methodGet methodType = iota - methodPost - methodPut - methodDelete - methodOptions - methodHead - methodPatch -) - -var methodMap = map[string]methodType{ - "GET": methodGet, - "POST": methodPost, - "PUT": methodPut, - "DELETE": methodDelete, - "OPTIONS": methodOptions, - "HEAD": methodHead, - "PATCH": methodPatch, -} - -// Route info with pre-computed handler chain -type route struct { - handlers [16]HandlerFunc - hlen int8 -} - -// Ultra-fast trie node -type node struct { - // Static children - direct byte lookup for first character - static [256]*node - - // Dynamic children - param *node - wildcard *node - - // Route data - routes [8]*route // index by method type - - // Node metadata - paramName string - isEnd bool -} - -// Path parsing with zero allocations -func splitPathFast(path string) []string { - if path == "/" { - return nil - } - - // Count segments first - count := 0 - start := 1 // skip leading / - for i := start; i < len(path); i++ { - if path[i] == '/' { - count++ - } - } - count++ // last segment - - // Pre-allocate exact size - segments := make([]string, 0, count) - start = 1 - for i := 1; i <= len(path); i++ { - if i == len(path) || path[i] == '/' { - if i > start { - segments = append(segments, path[start:i]) - } - start = i + 1 - } - } - return segments -} - -// Add route with minimal allocations -func (n *node) addRoute(method methodType, segments []string, handlers []HandlerFunc) { - curr := n - - for _, seg := range segments { - if len(seg) == 0 { - continue - } - - if seg[0] == ':' { - // Parameter route - if curr.param == nil { - curr.param = &node{paramName: seg[1:]} - } - curr = curr.param - } else if seg[0] == '*' { - // Wildcard route - if curr.wildcard == nil { - curr.wildcard = &node{paramName: seg[1:]} - } - curr = curr.wildcard - break // wildcard consumes rest - } else { - // Static route - use first byte for O(1) lookup - firstByte := seg[0] - if curr.static[firstByte] == nil { - curr.static[firstByte] = &node{} - } - curr = curr.static[firstByte] - } - } - - curr.isEnd = true - - // Store pre-computed handler chain - if curr.routes[method] == nil { - curr.routes[method] = &route{} - } - - r := curr.routes[method] - r.hlen = 0 - for i, h := range handlers { - if i >= 16 { - break // max 16 handlers - } - r.handlers[i] = h - r.hlen++ - } -} - -// Ultra-fast route matching -func (n *node) match(segments []string, params []Param, plen *int8) (*route, methodType, bool) { - curr := n - - for i, seg := range segments { - if len(seg) == 0 { - continue - } - - // Try static first (O(1) lookup) - firstByte := seg[0] - if next := curr.static[firstByte]; next != nil { - curr = next - continue - } - - // Try parameter - if curr.param != nil { - if *plen < 16 { - params[*plen] = Param{Key: curr.param.paramName, Value: seg} - (*plen)++ - } - curr = curr.param - continue - } - - // Try wildcard - if curr.wildcard != nil { - if *plen < 16 { - // Wildcard captures remaining path - remaining := strings.Join(segments[i:], "/") - params[*plen] = Param{Key: curr.wildcard.paramName, Value: remaining} - (*plen)++ - } - curr = curr.wildcard - break - } - - return nil, 0, false - } - - if !curr.isEnd { - return nil, 0, false - } - - // Find method (most common methods first) - if r := curr.routes[methodGet]; r != nil { - return r, methodGet, true - } - if r := curr.routes[methodPost]; r != nil { - return r, methodPost, true - } - if r := curr.routes[methodPut]; r != nil { - return r, methodPut, true - } - if r := curr.routes[methodDelete]; r != nil { - return r, methodDelete, true - } - - return nil, 0, false -} - -// ---------------------------- -// Engine implementation -// ---------------------------- - -type engine struct { - tree *node - middleware []HandlerFunc - servers []*http.Server - shutdown int32 -} - -func New() Engine { - return &engine{ - tree: &node{}, - } -} - -// Ultra-fast request handling -func (e *engine) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if atomic.LoadInt32(&e.shutdown) == 1 { - w.WriteHeader(503) - return - } - - // Get context from pool - c := ctxPool.Get().(*Ctx) - c.reset() - c.W = w - c.Req = r - c.engine = e - - // Parse path once - segments := splitPathFast(r.URL.Path) - - // Pre-allocated param array (on stack) - var paramArray [16]Param - var plen int8 - - // Match route - route, _, found := e.tree.match(segments, paramArray[:], &plen) - - if !found { - w.WriteHeader(404) - w.Write([]byte("404")) - ctxPool.Put(c) - return - } - - // Set params (no allocation) - if plen > 0 { - c.params = paramPool.Get().([]Param) - for i := int8(0); i < plen; i++ { - c.params = append(c.params, paramArray[i]) - } - c.plen = plen - } - - // Copy handlers (no allocation - fixed array) - copy(c.handlers[:], route.handlers[:route.hlen]) - c.hlen = route.hlen - - // Execute - if err := c.Next(); err != nil { - w.WriteHeader(500) - } - - ctxPool.Put(c) -} - -func (e *engine) Use(m ...HandlerFunc) { - e.middleware = append(e.middleware, m...) -} - -func (e *engine) addRoute(method, path string, groupMiddleware []HandlerFunc, h HandlerFunc) { - mt, ok := methodMap[method] - if !ok { - return - } - - segments := splitPathFast(path) - - // Build handler chain: global + group + route - totalLen := len(e.middleware) + len(groupMiddleware) + 1 - if totalLen > 16 { - totalLen = 16 // max handlers - } - - handlers := make([]HandlerFunc, 0, totalLen) - handlers = append(handlers, e.middleware...) - handlers = append(handlers, groupMiddleware...) - handlers = append(handlers, h) - - e.tree.addRoute(mt, segments, handlers) -} - -func (e *engine) GET(path string, h HandlerFunc) { e.addRoute("GET", path, nil, h) } -func (e *engine) POST(path string, h HandlerFunc) { e.addRoute("POST", path, nil, h) } -func (e *engine) PUT(path string, h HandlerFunc) { e.addRoute("PUT", path, nil, h) } -func (e *engine) DELETE(path string, h HandlerFunc) { e.addRoute("DELETE", path, nil, h) } - -// RouteGroup implementation -type routeGroup struct { - prefix string - engine *engine - middleware []HandlerFunc -} - -func (e *engine) Group(prefix string, m ...HandlerFunc) RouteGroup { - return &routeGroup{ - prefix: prefix, - engine: e, - middleware: m, - } -} - -func (g *routeGroup) Use(m ...HandlerFunc) { g.middleware = append(g.middleware, m...) } - -func (g *routeGroup) add(method, path string, h HandlerFunc) { - fullPath := g.prefix + path - g.engine.addRoute(method, fullPath, g.middleware, h) -} - -func (g *routeGroup) GET(path string, h HandlerFunc) { g.add("GET", path, h) } -func (g *routeGroup) POST(path string, h HandlerFunc) { g.add("POST", path, h) } -func (g *routeGroup) PUT(path string, h HandlerFunc) { g.add("PUT", path, h) } -func (g *routeGroup) DELETE(path string, h HandlerFunc) { g.add("DELETE", path, h) } - -// Ultra-fast static file serving -func (e *engine) Static(prefix, root string) { - if !strings.HasSuffix(prefix, "/") { - prefix += "/" - } - e.GET(strings.TrimSuffix(prefix, "/"), func(c *Ctx) error { - path := root + "/" - http.ServeFile(c.W, c.Req, path) - return nil - }) - e.GET(prefix+"*", func(c *Ctx) error { - filepath := c.Param("") - if filepath == "" { - filepath = "/" - } - path := root + "/" + filepath - http.ServeFile(c.W, c.Req, path) - return nil - }) -} - -func (e *engine) ListenAndServe(addr string) error { - srv := &http.Server{Addr: addr, Handler: e} - e.servers = append(e.servers, srv) - return srv.ListenAndServe() -} - -func (e *engine) Shutdown(ctx context.Context) error { - atomic.StoreInt32(&e.shutdown, 1) - for _, srv := range e.servers { - srv.Shutdown(ctx) - } - return nil -} - -// ---------------------------- -// Middleware -// ---------------------------- - -func Recover() HandlerFunc { - return func(c *Ctx) error { - defer func() { - if r := recover(); r != nil { - log.Printf("panic: %v", r) - c.Status(500) - c.SendString("Internal Server Error") - } - }() - return c.Next() - } -} - -func Logger() HandlerFunc { - return func(c *Ctx) error { - start := time.Now() - err := c.Next() - log.Printf("%s %s %v", c.Req.Method, c.Req.URL.Path, time.Since(start)) + err := sm.AddPermissions(permissions...) + if err != nil { return err } -} - -// ---------------------------- -// Example -// ---------------------------- - -func mai3n() { - app := New() - app.Use(Recover()) - - app.GET("/", func(c *Ctx) error { - return c.SendString("Hello World!") - }) - - app.GET("/user/:id", func(c *Ctx) error { - return c.SendString("User: " + c.Param("id")) - }) - - api := app.Group("/api") - api.GET("/ping", func(c *Ctx) error { - return c.JSON(map[string]any{"message": "pong"}) - }) - - app.Static("/static", "public") - - fmt.Println("Server starting on :8080") - if err := app.ListenAndServe(":8080"); err != nil { - log.Fatal(err) + roles := []*mq.Role{ + {Name: "publisher", Description: "Can publish tasks", Permissions: []string{"task.publish"}, CreatedAt: time.Now()}, + {Name: "consumer", Description: "Can consume tasks", Permissions: []string{"task.consume"}, CreatedAt: time.Now()}, + {Name: "admin", Description: "Full system access", Permissions: []string{"task.publish", "task.consume", "queue.manage", "admin.system"}, CreatedAt: time.Now()}, } + err = sm.AddRoles(roles...) + if err != nil { + return err + } + users := []*mq.User{ + {ID: "admin", Username: "admin", Roles: []string{"admin"}, CreatedAt: time.Now(), Password: "admin123"}, + {ID: "publisher", Username: "publisher", Roles: []string{"publisher"}, CreatedAt: time.Now(), Password: "pub123"}, + {ID: "consumer", Username: "consumer", Roles: []string{"consumer"}, CreatedAt: time.Now(), Password: "con123"}, + } + return sm.AddUsers(users...) } - -// ---------------------------- -// Performance optimizations: -// ---------------------------- -// 1. Zero allocations on hot path: -// - Fixed-size arrays instead of slices for handlers/params -// - Stack-allocated param arrays -// - Byte-based trie with O(1) static lookups -// - Pre-allocated pools for everything -// -// 2. Minimal interface overhead: -// - Direct memory operations -// - Embedded handler chains in context -// - Method type enum instead of string comparisons -// -// 3. Optimized data structures: -// - 256-element array for O(1) first-byte lookup -// - Linear search for params (faster than map for < 8 items) -// - Pre-computed route chains stored in trie -// -// 4. Fast path parsing: -// - Single-pass path splitting -// - Zero-allocation string operations -// - Minimal string comparisons -// -// This implementation should now match gofiber's performance by using -// similar zero-allocation techniques and optimized data structures. diff --git a/mq.go b/mq.go index 8e46e0b..c0ebd04 100644 --- a/mq.go +++ b/mq.go @@ -280,6 +280,12 @@ type Options struct { BrokerRateLimiter *RateLimiter // new field for broker rate limiting ConsumerRateLimiter *RateLimiter // new field for consumer rate limiting consumerTimeout time.Duration // timeout for consumer message processing (0 = no timeout) + adminAddr string // address for admin server + metricsAddr string // address for metrics server + enableSecurity bool // enable security features + enableMonitoring bool // enable monitoring features + username string // username for authentication + password string // password for authentication } func (o *Options) SetSyncMode(sync bool) { @@ -467,15 +473,19 @@ type Broker struct { listener net.Listener // Enhanced production features - connectionPool *ConnectionPool - healthChecker *HealthChecker - circuitBreaker *EnhancedCircuitBreaker - metricsCollector *MetricsCollector - messageStore MessageStore - isShutdown int32 - shutdown chan struct{} - wg sync.WaitGroup - logger logger.Logger + connectionPool *ConnectionPool + healthChecker *HealthChecker + circuitBreaker *EnhancedCircuitBreaker + metricsCollector *MetricsCollector + messageStore MessageStore + securityManager *SecurityManager + adminServer *AdminServer + metricsServer *MetricsServer + authenticatedConns storage.IMap[string, bool] // authenticated connections + isShutdown int32 + shutdown chan struct{} + wg sync.WaitGroup + logger logger.Logger } func NewBroker(opts ...Option) *Broker { @@ -491,13 +501,38 @@ func NewBroker(opts ...Option) *Broker { opts: options, // Enhanced production features - connectionPool: NewConnectionPool(1000), // max 1000 connections - healthChecker: NewHealthChecker(), - circuitBreaker: NewEnhancedCircuitBreaker(10, 30*time.Second), // 10 failures, 30s timeout - metricsCollector: NewMetricsCollector(), - messageStore: NewInMemoryMessageStore(), - shutdown: make(chan struct{}), - logger: options.Logger(), + connectionPool: NewConnectionPool(1000), // max 1000 connections + healthChecker: NewHealthChecker(), + circuitBreaker: NewEnhancedCircuitBreaker(10, 30*time.Second), // 10 failures, 30s timeout + metricsCollector: NewMetricsCollector(), + messageStore: NewInMemoryMessageStore(), + authenticatedConns: memory.New[string, bool](), + shutdown: make(chan struct{}), + logger: options.Logger(), + } + + if options.enableSecurity { + broker.securityManager = NewSecurityManager() + } + if options.enableMonitoring { + if options.adminAddr != "" { + broker.adminServer = NewAdminServer(broker, options.adminAddr, options.Logger()) + } + if options.metricsAddr != "" { + // Need to create MonitoringConfig, use default + config := &MonitoringConfig{ + EnableMetrics: true, + MetricsPort: 9090, // default + MetricsPath: "/metrics", + EnableHealthCheck: true, + HealthCheckPort: 8080, + HealthCheckPath: "/health", + HealthCheckInterval: time.Minute, + EnableLogging: true, + LogLevel: "info", + } + broker.metricsServer = NewMetricsServer(broker, config, options.Logger()) + } } broker.healthChecker.broker = broker @@ -508,6 +543,18 @@ func (b *Broker) Options() *Options { return b.opts } +func (b *Broker) SecurityManager() *SecurityManager { + return b.securityManager +} + +// InitializeSecurity initializes default users, roles, and permissions for development/testing +func (b *Broker) InitializeSecurity() error { + if b.securityManager == nil { + return fmt.Errorf("security manager not initialized") + } + return nil +} + func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error { consumerID, ok := GetConsumerID(ctx) if ok && consumerID != "" { @@ -553,6 +600,10 @@ func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error { b.publishers.Del(publisherID) } } + // Remove from authenticated connections + connID := conn.RemoteAddr().String() + b.authenticatedConns.Del(connID) + log.Printf("BROKER - Connection closed: address %s", conn.RemoteAddr()) return nil } @@ -563,8 +614,29 @@ func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) { } } +func (b *Broker) isAuthenticated(connID string) bool { + if b.securityManager == nil { + return true // no security, allow all + } + _, ok := b.authenticatedConns.Get(connID) + return ok +} + func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) { + connID := conn.RemoteAddr().String() + + // Check authentication for protected commands + if b.securityManager != nil && (msg.Command == consts.PUBLISH || msg.Command == consts.SUBSCRIBE) { + if !b.isAuthenticated(connID) { + b.logger.Warn("Unauthenticated access attempt", logger.Field{Key: "command", Value: msg.Command.String()}, logger.Field{Key: "conn", Value: connID}) + // Send error response + return + } + } + switch msg.Command { + case consts.AUTH: + b.AuthHandler(ctx, conn, msg) case consts.PUBLISH: b.PublishHandler(ctx, conn, msg) case consts.SUBSCRIBE: @@ -588,6 +660,54 @@ func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Con } } +func (b *Broker) AuthHandler(ctx context.Context, conn net.Conn, msg *codec.Message) { + connID := conn.RemoteAddr().String() + + // Parse auth credentials from payload + var authReq map[string]any + if err := json.Unmarshal(msg.Payload, &authReq); err != nil { + b.logger.Error("Invalid auth request", logger.Field{Key: "error", Value: err.Error()}) + return + } + + username, _ := authReq["username"].(string) + password, _ := authReq["password"].(string) + + // Authenticate + user, err := b.securityManager.Authenticate(ctx, map[string]any{ + "username": username, + "password": password, + }) + if err != nil { + b.logger.Warn("Authentication failed", logger.Field{Key: "username", Value: username}, logger.Field{Key: "conn", Value: connID}) + // Send AUTH_DENY + denyMsg, err := codec.NewMessage(consts.AUTH_DENY, []byte(fmt.Sprintf(`{"error":"%s"}`, err.Error())), "", msg.Headers) + if err != nil { + b.logger.Error("Failed to create AUTH_DENY message", logger.Field{Key: "error", Value: err.Error()}) + return + } + if err := b.send(ctx, conn, denyMsg); err != nil { + b.logger.Error("Failed to send AUTH_DENY", logger.Field{Key: "error", Value: err.Error()}) + } + return + } + + // Mark as authenticated + b.authenticatedConns.Set(connID, true) + + // Send AUTH_ACK + ackMsg, err := codec.NewMessage(consts.AUTH_ACK, []byte(`{"status":"authenticated"}`), "", msg.Headers) + if err != nil { + b.logger.Error("Failed to create AUTH_ACK message", logger.Field{Key: "error", Value: err.Error()}) + return + } + if err := b.send(ctx, conn, ackMsg); err != nil { + b.logger.Error("Failed to send AUTH_ACK", logger.Field{Key: "error", Value: err.Error()}) + } + + b.logger.Info("User authenticated", logger.Field{Key: "username", Value: user.Username}, logger.Field{Key: "conn", Value: connID}) +} + func (b *Broker) AdjustConsumerWorkers(noOfWorkers int, consumerID ...string) { b.consumers.ForEach(func(_ string, c *consumer) bool { return true @@ -1498,10 +1618,43 @@ func (b *Broker) StartEnhanced(ctx context.Context) error { b.wg.Add(1) go b.messageStoreCleanupRoutine() + // Start admin server if enabled + if b.adminServer != nil { + b.wg.Add(1) + go func() { + defer b.wg.Done() + if err := b.adminServer.Start(); err != nil { + b.logger.Error("Failed to start admin server", logger.Field{Key: "error", Value: err.Error()}) + } + }() + } + + // Start metrics server if enabled + if b.metricsServer != nil { + b.wg.Add(1) + go func() { + defer b.wg.Done() + if err := b.metricsServer.Start(ctx); err != nil { + b.logger.Error("Failed to start metrics server", logger.Field{Key: "error", Value: err.Error()}) + } + }() + } + b.logger.Info("Enhanced broker starting with production features enabled") // Start the enhanced broker with its own implementation - return b.startEnhancedBroker(ctx) + if err := b.startEnhancedBroker(ctx); err != nil { + return err + } + + // Wait for shutdown signal + <-b.shutdown + b.logger.Info("Enhanced broker shutting down") + + // Wait for all goroutines to finish + b.wg.Wait() + + return nil } // startEnhancedBroker starts the core broker functionality diff --git a/options.go b/options.go index e876813..f5baa63 100644 --- a/options.go +++ b/options.go @@ -314,3 +314,39 @@ func WithConsumerTimeout(timeout time.Duration) Option { opts.consumerTimeout = timeout } } + +func WithAdminAddr(addr string) Option { + return func(opts *Options) { + opts.adminAddr = addr + } +} + +func WithMetricsAddr(addr string) Option { + return func(opts *Options) { + opts.metricsAddr = addr + } +} + +func WithSecurity(enabled bool) Option { + return func(opts *Options) { + opts.enableSecurity = enabled + } +} + +func WithMonitoring(enabled bool) Option { + return func(opts *Options) { + opts.enableMonitoring = enabled + } +} + +func WithUsername(username string) Option { + return func(opts *Options) { + opts.username = username + } +} + +func WithPassword(password string) Option { + return func(opts *Options) { + opts.password = password + } +} diff --git a/publisher.go b/publisher.go index e2c82ec..77f51ac 100644 --- a/publisher.go +++ b/publisher.go @@ -18,10 +18,11 @@ import ( ) type Publisher struct { - opts *Options - id string - conn net.Conn - connLock sync.Mutex + opts *Options + id string + conn net.Conn + connLock sync.Mutex + authenticated bool } func NewPublisher(id string, opts ...Option) *Publisher { @@ -60,12 +61,70 @@ func (p *Publisher) ensureConnection(ctx context.Context) error { return fmt.Errorf("failed to connect to broker after retries: %w", err) } +// Auth authenticates the publisher with the broker +func (p *Publisher) Auth(ctx context.Context, username, password string) error { + if err := p.ensureConnection(ctx); err != nil { + return err + } + + p.connLock.Lock() + conn := p.conn + p.connLock.Unlock() + + authPayload := map[string]string{ + "username": username, + "password": password, + } + payload, err := json.Marshal(authPayload) + if err != nil { + return err + } + + headers := map[string]string{ + consts.PublisherKey: p.id, + consts.ContentType: consts.TypeJson, + } + msg, err := codec.NewMessage(consts.AUTH, payload, "", headers) + if err != nil { + return err + } + + err = codec.SendMessage(ctx, conn, msg) + if err != nil { + return err + } + + // Wait for AUTH_ACK + resp, err := codec.ReadMessage(ctx, conn) + if err != nil { + return err + } + + if resp.Command != consts.AUTH_ACK { + return fmt.Errorf("authentication failed: %s", string(resp.Payload)) + } + + return nil +} + // Publish method that uses the persistent connection. func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error { // Ensure connection is established. if err := p.ensureConnection(ctx); err != nil { return err } + + // Authenticate if security is enabled and not already authenticated + if p.opts.enableSecurity && !p.authenticated { + if p.opts.username == "" || p.opts.password == "" { + return fmt.Errorf("username and password required for authentication") + } + if err := p.Auth(ctx, p.opts.username, p.opts.password); err != nil { + return fmt.Errorf("authentication failed: %w", err) + } + p.authenticated = true + } + delay := p.opts.initialDelay for i := 0; i < p.opts.maxRetries; i++ { // Use the persistent connection. diff --git a/security.go b/security.go index 3e87541..2bdd7bb 100644 --- a/security.go +++ b/security.go @@ -10,8 +10,318 @@ import ( "strings" "sync" "time" + + "github.com/oarkflow/squealx" ) +// Storage interfaces for persistence +type UserStorage interface { + GetUser(username string) (*User, error) + GetPassword(username string) (string, error) + SaveUser(user *User, password string) error + ListUsers() ([]*User, error) + DeleteUser(username string) error +} + +type RoleStorage interface { + GetRole(name string) (*Role, error) + SaveRole(role *Role) error + ListRoles() ([]*Role, error) + DeleteRole(name string) error +} + +type PermissionStorage interface { + GetPermission(name string) (*Permission, error) + SavePermission(perm *Permission) error + ListPermissions() ([]*Permission, error) + DeletePermission(name string) error +} + +// MemoryUserStorage implements UserStorage using in-memory storage +type MemoryUserStorage struct { + users map[string]*User + passwords map[string]string + mu sync.RWMutex +} + +func NewMemoryUserStorage() *MemoryUserStorage { + return &MemoryUserStorage{ + users: make(map[string]*User), + passwords: make(map[string]string), + } +} + +func (mus *MemoryUserStorage) GetUser(username string) (*User, error) { + mus.mu.RLock() + defer mus.mu.RUnlock() + user, exists := mus.users[username] + if !exists { + return nil, fmt.Errorf("user not found") + } + return user, nil +} + +func (mus *MemoryUserStorage) SaveUser(user *User, password string) error { + mus.mu.Lock() + defer mus.mu.Unlock() + mus.users[user.Username] = user + mus.passwords[user.Username] = password + return nil +} + +func (mus *MemoryUserStorage) ListUsers() ([]*User, error) { + mus.mu.RLock() + defer mus.mu.RUnlock() + users := make([]*User, 0, len(mus.users)) + for _, user := range mus.users { + users = append(users, user) + } + return users, nil +} + +func (mus *MemoryUserStorage) GetPassword(username string) (string, error) { + mus.mu.RLock() + defer mus.mu.RUnlock() + password, exists := mus.passwords[username] + if !exists { + return "", fmt.Errorf("password not found") + } + return password, nil +} + +func (mus *MemoryUserStorage) DeleteUser(username string) error { + mus.mu.Lock() + defer mus.mu.Unlock() + delete(mus.users, username) + delete(mus.passwords, username) + return nil +} + +// MemoryRoleStorage implements RoleStorage using in-memory storage +type MemoryRoleStorage struct { + roles map[string]*Role + mu sync.RWMutex +} + +func NewMemoryRoleStorage() *MemoryRoleStorage { + return &MemoryRoleStorage{ + roles: make(map[string]*Role), + } +} + +func (mrs *MemoryRoleStorage) GetRole(name string) (*Role, error) { + mrs.mu.RLock() + defer mrs.mu.RUnlock() + role, exists := mrs.roles[name] + if !exists { + return nil, fmt.Errorf("role not found") + } + return role, nil +} + +func (mrs *MemoryRoleStorage) SaveRole(role *Role) error { + mrs.mu.Lock() + defer mrs.mu.Unlock() + mrs.roles[role.Name] = role + return nil +} + +func (mrs *MemoryRoleStorage) ListRoles() ([]*Role, error) { + mrs.mu.RLock() + defer mrs.mu.RUnlock() + roles := make([]*Role, 0, len(mrs.roles)) + for _, role := range mrs.roles { + roles = append(roles, role) + } + return roles, nil +} + +func (mrs *MemoryRoleStorage) DeleteRole(name string) error { + mrs.mu.Lock() + defer mrs.mu.Unlock() + delete(mrs.roles, name) + return nil +} + +// MemoryPermissionStorage implements PermissionStorage using in-memory storage +type MemoryPermissionStorage struct { + permissions map[string]*Permission + mu sync.RWMutex +} + +func NewMemoryPermissionStorage() *MemoryPermissionStorage { + return &MemoryPermissionStorage{ + permissions: make(map[string]*Permission), + } +} + +func (mps *MemoryPermissionStorage) GetPermission(name string) (*Permission, error) { + mps.mu.RLock() + defer mps.mu.RUnlock() + perm, exists := mps.permissions[name] + if !exists { + return nil, fmt.Errorf("permission not found") + } + return perm, nil +} + +func (mps *MemoryPermissionStorage) SavePermission(perm *Permission) error { + mps.mu.Lock() + defer mps.mu.Unlock() + mps.permissions[perm.Name] = perm + return nil +} + +func (mps *MemoryPermissionStorage) ListPermissions() ([]*Permission, error) { + mps.mu.RLock() + defer mps.mu.RUnlock() + perms := make([]*Permission, 0, len(mps.permissions)) + for _, perm := range mps.permissions { + perms = append(perms, perm) + } + return perms, nil +} + +func (mps *MemoryPermissionStorage) DeletePermission(name string) error { + mps.mu.Lock() + defer mps.mu.Unlock() + delete(mps.permissions, name) + return nil +} + +// SQLUserStorage implements UserStorage using SQL database +type SQLUserStorage struct { + db *squealx.DB +} + +func NewSQLUserStorage(db *squealx.DB) *SQLUserStorage { + return &SQLUserStorage{db: db} +} + +func (sus *SQLUserStorage) GetUser(username string) (*User, error) { + var user User + err := sus.db.Get(&user, "SELECT id, username, roles, permissions, metadata, created_at, last_login_at FROM users WHERE username = $1", username) + if err != nil { + return nil, err + } + return &user, nil +} + +func (sus *SQLUserStorage) GetPassword(username string) (string, error) { + var password string + err := sus.db.Get(&password, "SELECT password FROM users WHERE username = $1", username) + if err != nil { + return "", err + } + return password, nil +} + +func (sus *SQLUserStorage) SaveUser(user *User, password string) error { + _, err := sus.db.Exec(` + INSERT INTO users (id, username, password, roles, permissions, metadata, created_at, last_login_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (username) DO UPDATE SET + password = EXCLUDED.password, + roles = EXCLUDED.roles, + permissions = EXCLUDED.permissions, + metadata = EXCLUDED.metadata, + last_login_at = EXCLUDED.last_login_at`, + user.ID, user.Username, password, user.Roles, user.Permissions, user.Metadata, user.CreatedAt, user.LastLoginAt) + return err +} + +func (sus *SQLUserStorage) ListUsers() ([]*User, error) { + var users []*User + err := sus.db.Select(&users, "SELECT id, username, roles, permissions, metadata, created_at, last_login_at FROM users") + return users, err +} + +func (sus *SQLUserStorage) DeleteUser(username string) error { + _, err := sus.db.Exec("DELETE FROM users WHERE username = $1", username) + return err +} + +// SQLRoleStorage implements RoleStorage using SQL database +type SQLRoleStorage struct { + db *squealx.DB +} + +func NewSQLRoleStorage(db *squealx.DB) *SQLRoleStorage { + return &SQLRoleStorage{db: db} +} + +func (srs *SQLRoleStorage) GetRole(name string) (*Role, error) { + var role Role + err := srs.db.Get(&role, "SELECT name, description, permissions, created_at FROM roles WHERE name = $1", name) + if err != nil { + return nil, err + } + return &role, nil +} + +func (srs *SQLRoleStorage) SaveRole(role *Role) error { + _, err := srs.db.Exec(` + INSERT INTO roles (name, description, permissions, created_at) + VALUES ($1, $2, $3, $4) + ON CONFLICT (name) DO UPDATE SET + description = EXCLUDED.description, + permissions = EXCLUDED.permissions`, + role.Name, role.Description, role.Permissions, role.CreatedAt) + return err +} + +func (srs *SQLRoleStorage) ListRoles() ([]*Role, error) { + var roles []*Role + err := srs.db.Select(&roles, "SELECT name, description, permissions, created_at FROM roles") + return roles, err +} + +func (srs *SQLRoleStorage) DeleteRole(name string) error { + _, err := srs.db.Exec("DELETE FROM roles WHERE name = $1", name) + return err +} + +// SQLPermissionStorage implements PermissionStorage using SQL database +type SQLPermissionStorage struct { + db *squealx.DB +} + +func NewSQLPermissionStorage(db *squealx.DB) *SQLPermissionStorage { + return &SQLPermissionStorage{db: db} +} + +func (sps *SQLPermissionStorage) GetPermission(name string) (*Permission, error) { + var perm Permission + err := sps.db.Get(&perm, "SELECT name, resource, action, description, created_at FROM permissions WHERE name = $1", name) + if err != nil { + return nil, err + } + return &perm, nil +} + +func (sps *SQLPermissionStorage) SavePermission(perm *Permission) error { + _, err := sps.db.Exec(` + INSERT INTO permissions (name, resource, action, description, created_at) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (name) DO UPDATE SET + resource = EXCLUDED.resource, + action = EXCLUDED.action, + description = EXCLUDED.description`, + perm.Name, perm.Resource, perm.Action, perm.Description, perm.CreatedAt) + return err +} + +func (sps *SQLPermissionStorage) ListPermissions() ([]*Permission, error) { + var perms []*Permission + err := sps.db.Select(&perms, "SELECT name, resource, action, description, created_at FROM permissions") + return perms, err +} + +func (sps *SQLPermissionStorage) DeletePermission(name string) error { + _, err := sps.db.Exec("DELETE FROM permissions WHERE name = $1", name) + return err +} + // SecurityManager handles authentication, authorization, and security policies type SecurityManager struct { authProviders map[string]AuthProvider @@ -34,6 +344,7 @@ type AuthProvider interface { type User struct { ID string `json:"id"` Username string `json:"username"` + Password string `json:"-"` Roles []string `json:"roles"` Permissions []string `json:"permissions"` Metadata map[string]any `json:"metadata,omitempty"` @@ -43,9 +354,9 @@ type User struct { // RoleManager manages user roles and permissions type RoleManager struct { - roles map[string]*Role - permissions map[string]*Permission - mu sync.RWMutex + roleStorage RoleStorage + permissionStorage PermissionStorage + mu sync.RWMutex } // Role represents a user role with associated permissions @@ -123,93 +434,47 @@ func NewSecurityManager() *SecurityManager { key := make([]byte, 32) rand.Read(key) - return &SecurityManager{ + // Create memory storages by default + userStorage := NewMemoryUserStorage() + roleStorage := NewMemoryRoleStorage() + permissionStorage := NewMemoryPermissionStorage() + + sm := &SecurityManager{ authProviders: make(map[string]AuthProvider), - roleManager: NewRoleManager(), + roleManager: NewRoleManager(roleStorage, permissionStorage), rateLimiter: NewSecurityRateLimiter(5, time.Minute*15), // 5 attempts per 15 minutes auditLogger: NewAuditLogger(10000), sessionManager: NewSessionManager(time.Hour * 24), // 24 hour sessions encryptionKey: key, } + + // Add default basic auth provider + basicProvider := NewBasicAuthProvider(userStorage) + sm.AddAuthProvider(basicProvider) + + return sm } // NewRoleManager creates a new role manager -func NewRoleManager() *RoleManager { - rm := &RoleManager{ - roles: make(map[string]*Role), - permissions: make(map[string]*Permission), +func NewRoleManager(roleStorage RoleStorage, permissionStorage PermissionStorage) *RoleManager { + return &RoleManager{ + roleStorage: roleStorage, + permissionStorage: permissionStorage, } - - // Initialize default permissions - rm.AddPermission(&Permission{ - Name: "task.publish", - Resource: "task", - Action: "publish", - Description: "Publish tasks to queues", - CreatedAt: time.Now(), - }) - - rm.AddPermission(&Permission{ - Name: "task.consume", - Resource: "task", - Action: "consume", - Description: "Consume tasks from queues", - CreatedAt: time.Now(), - }) - - rm.AddPermission(&Permission{ - Name: "queue.manage", - Resource: "queue", - Action: "manage", - Description: "Manage queues", - CreatedAt: time.Now(), - }) - - rm.AddPermission(&Permission{ - Name: "admin.system", - Resource: "system", - Action: "admin", - Description: "System administration", - CreatedAt: time.Now(), - }) - - // Initialize default roles - rm.AddRole(&Role{ - Name: "publisher", - Description: "Can publish tasks", - Permissions: []string{"task.publish"}, - CreatedAt: time.Now(), - }) - - rm.AddRole(&Role{ - Name: "consumer", - Description: "Can consume tasks", - Permissions: []string{"task.consume"}, - CreatedAt: time.Now(), - }) - - rm.AddRole(&Role{ - Name: "admin", - Description: "Full system access", - Permissions: []string{"task.publish", "task.consume", "queue.manage", "admin.system"}, - CreatedAt: time.Now(), - }) - - return rm } // AddPermission adds a permission to the role manager -func (rm *RoleManager) AddPermission(perm *Permission) { +func (rm *RoleManager) AddPermission(perm *Permission) error { rm.mu.Lock() defer rm.mu.Unlock() - rm.permissions[perm.Name] = perm + return rm.permissionStorage.SavePermission(perm) } // AddRole adds a role to the role manager -func (rm *RoleManager) AddRole(role *Role) { +func (rm *RoleManager) AddRole(role *Role) error { rm.mu.Lock() defer rm.mu.Unlock() - rm.roles[role.Name] = role + return rm.roleStorage.SaveRole(role) } // HasPermission checks if a user has a specific permission @@ -218,11 +483,13 @@ func (rm *RoleManager) HasPermission(user *User, permission string) bool { defer rm.mu.RUnlock() for _, roleName := range user.Roles { - if role, exists := rm.roles[roleName]; exists { - for _, perm := range role.Permissions { - if perm == permission { - return true - } + role, err := rm.roleStorage.GetRole(roleName) + if err != nil { + continue + } + for _, perm := range role.Permissions { + if perm == permission { + return true } } } @@ -236,10 +503,12 @@ func (rm *RoleManager) GetUserPermissions(user *User) []string { permissions := make(map[string]bool) for _, roleName := range user.Roles { - if role, exists := rm.roles[roleName]; exists { - for _, perm := range role.Permissions { - permissions[perm] = true - } + role, err := rm.roleStorage.GetRole(roleName) + if err != nil { + continue + } + for _, perm := range role.Permissions { + permissions[perm] = true } } @@ -469,6 +738,60 @@ func (sm *SecurityManager) Authenticate(ctx context.Context, credentials map[str return nil, fmt.Errorf("authentication failed: %w", lastErr) } +func (sm *SecurityManager) AddPermission(perm *Permission) error { + if perm == nil || (perm.Name == "" && (perm.Resource == "" || perm.Action == "")) { + return fmt.Errorf("invalid permission") + } + return sm.roleManager.AddPermission(perm) +} + +func (sm *SecurityManager) AddRole(role *Role) error { + if role == nil || role.Name == "" { + return fmt.Errorf("invalid role") + } + return sm.roleManager.AddRole(role) +} + +func (sm *SecurityManager) AddUsers(users ...*User) error { + for _, user := range users { + if user == nil || user.Username == "" || user.Password == "" { + return fmt.Errorf("invalid user") + } + if err := sm.AddUser(user); err != nil { + return err + } + } + return nil +} + +func (sm *SecurityManager) AddRoles(roles ...*Role) error { + for _, role := range roles { + if err := sm.AddRole(role); err != nil { + return err + } + } + return nil +} + +func (sm *SecurityManager) AddPermissions(perms ...*Permission) error { + for _, perm := range perms { + if err := sm.AddPermission(perm); err != nil { + return err + } + } + return nil +} + +// AddUser adds a user to the system +func (sm *SecurityManager) AddUser(user *User) error { + for _, provider := range sm.authProviders { + if bap, ok := provider.(*BasicAuthProvider); ok { + return bap.AddUser(user, user.Password) + } + } + return fmt.Errorf("no suitable auth provider found") +} + // Authorize checks if a user is authorized for an action func (sm *SecurityManager) Authorize(user *User, resource, action string) error { permission := fmt.Sprintf("%s.%s", resource, action) @@ -551,13 +874,12 @@ func (sm *SecurityManager) Decrypt(data []byte) ([]byte, error) { // BasicAuthProvider implements basic username/password authentication type BasicAuthProvider struct { - users map[string]*User - mu sync.RWMutex + userStorage UserStorage } -func NewBasicAuthProvider() *BasicAuthProvider { +func NewBasicAuthProvider(userStorage UserStorage) *BasicAuthProvider { return &BasicAuthProvider{ - users: make(map[string]*User), + userStorage: userStorage, } } @@ -576,16 +898,13 @@ func (bap *BasicAuthProvider) Authenticate(ctx context.Context, credentials map[ return nil, fmt.Errorf("password required") } - bap.mu.RLock() - user, exists := bap.users[username] - bap.mu.RUnlock() - - if !exists { + user, err := bap.userStorage.GetUser(username) + if err != nil { return nil, fmt.Errorf("user not found") } - // In production, compare hashed passwords - if password != "password" { // Placeholder + storedPassword, err := bap.userStorage.GetPassword(username) + if err != nil || storedPassword != password { return nil, fmt.Errorf("invalid password") } @@ -611,14 +930,9 @@ func (bap *BasicAuthProvider) ValidateToken(token string) (*User, error) { } func (bap *BasicAuthProvider) AddUser(user *User, password string) error { - bap.mu.Lock() - defer bap.mu.Unlock() - // In production, hash the password user.CreatedAt = time.Now() - bap.users[user.Username] = user - - return nil + return bap.userStorage.SaveUser(user, password) } // generateID generates a random ID