This commit is contained in:
sujit
2025-09-24 16:01:07 +05:45
parent cb2869c98b
commit ca422e5fda
10 changed files with 826 additions and 708 deletions

View File

@@ -445,7 +445,7 @@ func validateCommand(cmd consts.CMD) error {
}
func validateQueue(queue string) error {
if len(queue) == 0 || len(queue) > MaxQueueLength {
if len(queue) > MaxQueueLength {
return ErrInvalidQueue
}
return nil

View File

@@ -6,6 +6,9 @@ func (c CMD) IsValid() bool { return c >= PING && c <= CONSUMER_STOP }
const (
PING CMD = iota + 1
AUTH
AUTH_ACK
AUTH_DENY
SUBSCRIBE
SUBSCRIBE_ACK
@@ -42,6 +45,12 @@ func (c CMD) String() string {
switch c {
case PING:
return "PING"
case AUTH:
return "AUTH"
case AUTH_ACK:
return "AUTH_ACK"
case AUTH_DENY:
return "AUTH_DENY"
case SUBSCRIBE:
return "SUBSCRIBE"
case SUBSCRIBE_ACK:

View File

@@ -162,6 +162,40 @@ func (c *Consumer) subscribe(ctx context.Context, queue string) error {
return c.waitForAck(ctx, c.conn)
}
// Auth authenticates the consumer with the broker
func (c *Consumer) Auth(ctx context.Context, username, password string) error {
authPayload := map[string]string{
"username": username,
"password": password,
}
payload, err := json.Marshal(authPayload)
if err != nil {
return err
}
headers := HeadersWithConsumerID(ctx, c.id)
msg, err := codec.NewMessage(consts.AUTH, payload, "", headers)
if err != nil {
return fmt.Errorf("error creating auth message: %v", err)
}
if err := c.send(ctx, c.conn, msg); err != nil {
return fmt.Errorf("error sending auth: %v", err)
}
// Wait for AUTH_ACK
resp, err := c.receive(ctx, c.conn)
if err != nil {
return fmt.Errorf("error receiving auth response: %v", err)
}
if resp.Command != consts.AUTH_ACK {
return fmt.Errorf("authentication failed: %s", string(resp.Payload))
}
return nil
}
func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error {
fmt.Println("Consumer closed")
return nil
@@ -559,6 +593,16 @@ func (c *Consumer) Consume(ctx context.Context) error {
return fmt.Errorf("initial connection failed: %w", err)
}
// Authenticate if security is enabled
if c.opts.enableSecurity {
if c.opts.username == "" || c.opts.password == "" {
return fmt.Errorf("username and password required for authentication")
}
if err := c.Auth(ctx, c.opts.username, c.opts.password); err != nil {
return fmt.Errorf("authentication failed: %w", err)
}
}
// Initialize pool
c.pool = NewPool(
c.opts.numOfWorkers,

21
examples/consumer.go Normal file
View File

@@ -0,0 +1,21 @@
package main
import (
"context"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/examples/tasks"
)
func main() {
n := &tasks.Node6{}
consumer1 := mq.NewConsumer("F", "queue1", n.ProcessTask,
mq.WithBrokerURL(":8081"),
mq.WithHTTPApi(true),
mq.WithWorkerPool(100, 4, 50000),
mq.WithSecurity(true),
mq.WithUsername("consumer"),
mq.WithPassword("con123"),
)
consumer1.Consume(context.Background())
}

29
examples/publisher.go Normal file
View File

@@ -0,0 +1,29 @@
package main
import (
"context"
"fmt"
"github.com/oarkflow/mq"
)
func main() {
payload := []byte(`{"phone": "+123456789", "email": "abc.xyz@gmail.com", "age": 12}`)
task := mq.Task{
Payload: payload,
}
publisher := mq.NewPublisher("publish-1",
mq.WithBrokerURL(":8081"),
mq.WithSecurity(true),
mq.WithUsername("publisher"),
mq.WithPassword("pub123"),
)
for i := 0; i < 2; i++ {
// publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"))
err := publisher.Publish(context.Background(), task, "queue1")
if err != nil {
panic(err)
}
}
fmt.Println("Async task published successfully")
}

View File

@@ -1,604 +1,57 @@
// fast_http_router.go
// Ultra-high performance HTTP router in Go matching gofiber speed
// Key optimizations:
// - Zero allocations on hot path (no slice/map allocations per request)
// - Byte-based routing for maximum speed
// - Pre-allocated pools for everything
// - Minimal interface overhead
// - Direct memory operations where possible
package main
import (
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/oarkflow/mq"
"github.com/oarkflow/mq/examples/tasks"
)
// ----------------------------
// Public Interfaces (minimal overhead)
// ----------------------------
type HandlerFunc func(*Ctx) error
type Engine interface {
http.Handler
Group(prefix string, m ...HandlerFunc) RouteGroup
Use(m ...HandlerFunc)
GET(path string, h HandlerFunc)
POST(path string, h HandlerFunc)
PUT(path string, h HandlerFunc)
DELETE(path string, h HandlerFunc)
Static(prefix, root string)
ListenAndServe(addr string) error
Shutdown(ctx context.Context) error
func main() {
b := mq.NewBroker(
mq.WithCallback(tasks.Callback),
mq.WithBrokerURL(":8081"),
mq.WithSecurity(true),
// mq.WithMonitoring(true),
// mq.WithAdminAddr(":8080"),
// mq.WithMetricsAddr(":9090"),
)
InitializeDefaults(b.SecurityManager())
// b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert"))
if err := b.InitializeSecurity(); err != nil {
panic(err)
}
b.NewQueue("queue1")
b.NewQueue("queue2")
b.StartEnhanced(context.Background())
}
type RouteGroup interface {
Use(m ...HandlerFunc)
GET(path string, h HandlerFunc)
POST(path string, h HandlerFunc)
PUT(path string, h HandlerFunc)
DELETE(path string, h HandlerFunc)
}
// ----------------------------
// Ultra-fast param extraction
// ----------------------------
type Param struct {
Key string
Value string
}
// Pre-allocated param slices to avoid any allocations
var paramPool = sync.Pool{
New: func() any {
return make([]Param, 0, 16)
},
}
// ----------------------------
// Context with zero allocations
// ----------------------------
type Ctx struct {
W http.ResponseWriter
Req *http.Request
params []Param
index int8
plen int8
// Embedded handler chain (no slice allocation)
handlers [16]HandlerFunc // fixed size, 99% of routes have < 16 handlers
hlen int8
status int
engine *engine
}
var ctxPool = sync.Pool{
New: func() any {
return &Ctx{}
},
}
func (c *Ctx) reset() {
c.W = nil
c.Req = nil
if c.params != nil {
paramPool.Put(c.params[:0])
c.params = nil
// InitializeDefaults adds default permissions, roles, and users for development/testing
func InitializeDefaults(sm *mq.SecurityManager) error {
permissions := []*mq.Permission{
{Name: "task.publish", Resource: "task", Action: "publish", Description: "Publish tasks to queues", CreatedAt: time.Now()},
{Name: "task.consume", Resource: "task", Action: "consume", Description: "Consume tasks from queues", CreatedAt: time.Now()},
{Name: "queue.manage", Resource: "queue", Action: "manage", Description: "Manage queues", CreatedAt: time.Now()},
{Name: "admin.system", Resource: "system", Action: "admin", Description: "System administration", CreatedAt: time.Now()},
}
c.index = 0
c.plen = 0
c.hlen = 0
c.status = 0
c.engine = nil
}
// Ultra-fast param lookup (linear search is faster than map for < 8 params)
func (c *Ctx) Param(key string) string {
for i := int8(0); i < c.plen; i++ {
if c.params[i].Key == key {
return c.params[i].Value
}
}
return ""
}
func (c *Ctx) addParam(key, value string) {
if c.params == nil {
c.params = paramPool.Get().([]Param)
}
if c.plen < 16 { // max 16 params
c.params = append(c.params, Param{Key: key, Value: value})
c.plen++
}
}
// Zero-allocation header operations
func (c *Ctx) Set(key, val string) {
if c.W != nil {
c.W.Header().Set(key, val)
}
}
func (c *Ctx) Get(key string) string {
if c.Req != nil {
return c.Req.Header.Get(key)
}
return ""
}
// Ultra-fast response methods
func (c *Ctx) SendString(s string) error {
if c.status != 0 {
c.W.WriteHeader(c.status)
}
_, err := io.WriteString(c.W, s)
return err
}
func (c *Ctx) JSON(v any) error {
c.Set("Content-Type", "application/json")
if c.status != 0 {
c.W.WriteHeader(c.status)
}
return json.NewEncoder(c.W).Encode(v)
}
func (c *Ctx) Status(code int) { c.status = code }
func (c *Ctx) Next() error {
for c.index < c.hlen {
h := c.handlers[c.index]
c.index++
if err := h(c); err != nil {
return err
}
}
return nil
}
// ----------------------------
// Ultra-fast byte-based router
// ----------------------------
type methodType uint8
const (
methodGet methodType = iota
methodPost
methodPut
methodDelete
methodOptions
methodHead
methodPatch
)
var methodMap = map[string]methodType{
"GET": methodGet,
"POST": methodPost,
"PUT": methodPut,
"DELETE": methodDelete,
"OPTIONS": methodOptions,
"HEAD": methodHead,
"PATCH": methodPatch,
}
// Route info with pre-computed handler chain
type route struct {
handlers [16]HandlerFunc
hlen int8
}
// Ultra-fast trie node
type node struct {
// Static children - direct byte lookup for first character
static [256]*node
// Dynamic children
param *node
wildcard *node
// Route data
routes [8]*route // index by method type
// Node metadata
paramName string
isEnd bool
}
// Path parsing with zero allocations
func splitPathFast(path string) []string {
if path == "/" {
return nil
}
// Count segments first
count := 0
start := 1 // skip leading /
for i := start; i < len(path); i++ {
if path[i] == '/' {
count++
}
}
count++ // last segment
// Pre-allocate exact size
segments := make([]string, 0, count)
start = 1
for i := 1; i <= len(path); i++ {
if i == len(path) || path[i] == '/' {
if i > start {
segments = append(segments, path[start:i])
}
start = i + 1
}
}
return segments
}
// Add route with minimal allocations
func (n *node) addRoute(method methodType, segments []string, handlers []HandlerFunc) {
curr := n
for _, seg := range segments {
if len(seg) == 0 {
continue
}
if seg[0] == ':' {
// Parameter route
if curr.param == nil {
curr.param = &node{paramName: seg[1:]}
}
curr = curr.param
} else if seg[0] == '*' {
// Wildcard route
if curr.wildcard == nil {
curr.wildcard = &node{paramName: seg[1:]}
}
curr = curr.wildcard
break // wildcard consumes rest
} else {
// Static route - use first byte for O(1) lookup
firstByte := seg[0]
if curr.static[firstByte] == nil {
curr.static[firstByte] = &node{}
}
curr = curr.static[firstByte]
}
}
curr.isEnd = true
// Store pre-computed handler chain
if curr.routes[method] == nil {
curr.routes[method] = &route{}
}
r := curr.routes[method]
r.hlen = 0
for i, h := range handlers {
if i >= 16 {
break // max 16 handlers
}
r.handlers[i] = h
r.hlen++
}
}
// Ultra-fast route matching
func (n *node) match(segments []string, params []Param, plen *int8) (*route, methodType, bool) {
curr := n
for i, seg := range segments {
if len(seg) == 0 {
continue
}
// Try static first (O(1) lookup)
firstByte := seg[0]
if next := curr.static[firstByte]; next != nil {
curr = next
continue
}
// Try parameter
if curr.param != nil {
if *plen < 16 {
params[*plen] = Param{Key: curr.param.paramName, Value: seg}
(*plen)++
}
curr = curr.param
continue
}
// Try wildcard
if curr.wildcard != nil {
if *plen < 16 {
// Wildcard captures remaining path
remaining := strings.Join(segments[i:], "/")
params[*plen] = Param{Key: curr.wildcard.paramName, Value: remaining}
(*plen)++
}
curr = curr.wildcard
break
}
return nil, 0, false
}
if !curr.isEnd {
return nil, 0, false
}
// Find method (most common methods first)
if r := curr.routes[methodGet]; r != nil {
return r, methodGet, true
}
if r := curr.routes[methodPost]; r != nil {
return r, methodPost, true
}
if r := curr.routes[methodPut]; r != nil {
return r, methodPut, true
}
if r := curr.routes[methodDelete]; r != nil {
return r, methodDelete, true
}
return nil, 0, false
}
// ----------------------------
// Engine implementation
// ----------------------------
type engine struct {
tree *node
middleware []HandlerFunc
servers []*http.Server
shutdown int32
}
func New() Engine {
return &engine{
tree: &node{},
}
}
// Ultra-fast request handling
func (e *engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if atomic.LoadInt32(&e.shutdown) == 1 {
w.WriteHeader(503)
return
}
// Get context from pool
c := ctxPool.Get().(*Ctx)
c.reset()
c.W = w
c.Req = r
c.engine = e
// Parse path once
segments := splitPathFast(r.URL.Path)
// Pre-allocated param array (on stack)
var paramArray [16]Param
var plen int8
// Match route
route, _, found := e.tree.match(segments, paramArray[:], &plen)
if !found {
w.WriteHeader(404)
w.Write([]byte("404"))
ctxPool.Put(c)
return
}
// Set params (no allocation)
if plen > 0 {
c.params = paramPool.Get().([]Param)
for i := int8(0); i < plen; i++ {
c.params = append(c.params, paramArray[i])
}
c.plen = plen
}
// Copy handlers (no allocation - fixed array)
copy(c.handlers[:], route.handlers[:route.hlen])
c.hlen = route.hlen
// Execute
if err := c.Next(); err != nil {
w.WriteHeader(500)
}
ctxPool.Put(c)
}
func (e *engine) Use(m ...HandlerFunc) {
e.middleware = append(e.middleware, m...)
}
func (e *engine) addRoute(method, path string, groupMiddleware []HandlerFunc, h HandlerFunc) {
mt, ok := methodMap[method]
if !ok {
return
}
segments := splitPathFast(path)
// Build handler chain: global + group + route
totalLen := len(e.middleware) + len(groupMiddleware) + 1
if totalLen > 16 {
totalLen = 16 // max handlers
}
handlers := make([]HandlerFunc, 0, totalLen)
handlers = append(handlers, e.middleware...)
handlers = append(handlers, groupMiddleware...)
handlers = append(handlers, h)
e.tree.addRoute(mt, segments, handlers)
}
func (e *engine) GET(path string, h HandlerFunc) { e.addRoute("GET", path, nil, h) }
func (e *engine) POST(path string, h HandlerFunc) { e.addRoute("POST", path, nil, h) }
func (e *engine) PUT(path string, h HandlerFunc) { e.addRoute("PUT", path, nil, h) }
func (e *engine) DELETE(path string, h HandlerFunc) { e.addRoute("DELETE", path, nil, h) }
// RouteGroup implementation
type routeGroup struct {
prefix string
engine *engine
middleware []HandlerFunc
}
func (e *engine) Group(prefix string, m ...HandlerFunc) RouteGroup {
return &routeGroup{
prefix: prefix,
engine: e,
middleware: m,
}
}
func (g *routeGroup) Use(m ...HandlerFunc) { g.middleware = append(g.middleware, m...) }
func (g *routeGroup) add(method, path string, h HandlerFunc) {
fullPath := g.prefix + path
g.engine.addRoute(method, fullPath, g.middleware, h)
}
func (g *routeGroup) GET(path string, h HandlerFunc) { g.add("GET", path, h) }
func (g *routeGroup) POST(path string, h HandlerFunc) { g.add("POST", path, h) }
func (g *routeGroup) PUT(path string, h HandlerFunc) { g.add("PUT", path, h) }
func (g *routeGroup) DELETE(path string, h HandlerFunc) { g.add("DELETE", path, h) }
// Ultra-fast static file serving
func (e *engine) Static(prefix, root string) {
if !strings.HasSuffix(prefix, "/") {
prefix += "/"
}
e.GET(strings.TrimSuffix(prefix, "/"), func(c *Ctx) error {
path := root + "/"
http.ServeFile(c.W, c.Req, path)
return nil
})
e.GET(prefix+"*", func(c *Ctx) error {
filepath := c.Param("")
if filepath == "" {
filepath = "/"
}
path := root + "/" + filepath
http.ServeFile(c.W, c.Req, path)
return nil
})
}
func (e *engine) ListenAndServe(addr string) error {
srv := &http.Server{Addr: addr, Handler: e}
e.servers = append(e.servers, srv)
return srv.ListenAndServe()
}
func (e *engine) Shutdown(ctx context.Context) error {
atomic.StoreInt32(&e.shutdown, 1)
for _, srv := range e.servers {
srv.Shutdown(ctx)
}
return nil
}
// ----------------------------
// Middleware
// ----------------------------
func Recover() HandlerFunc {
return func(c *Ctx) error {
defer func() {
if r := recover(); r != nil {
log.Printf("panic: %v", r)
c.Status(500)
c.SendString("Internal Server Error")
}
}()
return c.Next()
}
}
func Logger() HandlerFunc {
return func(c *Ctx) error {
start := time.Now()
err := c.Next()
log.Printf("%s %s %v", c.Req.Method, c.Req.URL.Path, time.Since(start))
err := sm.AddPermissions(permissions...)
if err != nil {
return err
}
}
// ----------------------------
// Example
// ----------------------------
func mai3n() {
app := New()
app.Use(Recover())
app.GET("/", func(c *Ctx) error {
return c.SendString("Hello World!")
})
app.GET("/user/:id", func(c *Ctx) error {
return c.SendString("User: " + c.Param("id"))
})
api := app.Group("/api")
api.GET("/ping", func(c *Ctx) error {
return c.JSON(map[string]any{"message": "pong"})
})
app.Static("/static", "public")
fmt.Println("Server starting on :8080")
if err := app.ListenAndServe(":8080"); err != nil {
log.Fatal(err)
roles := []*mq.Role{
{Name: "publisher", Description: "Can publish tasks", Permissions: []string{"task.publish"}, CreatedAt: time.Now()},
{Name: "consumer", Description: "Can consume tasks", Permissions: []string{"task.consume"}, CreatedAt: time.Now()},
{Name: "admin", Description: "Full system access", Permissions: []string{"task.publish", "task.consume", "queue.manage", "admin.system"}, CreatedAt: time.Now()},
}
err = sm.AddRoles(roles...)
if err != nil {
return err
}
users := []*mq.User{
{ID: "admin", Username: "admin", Roles: []string{"admin"}, CreatedAt: time.Now(), Password: "admin123"},
{ID: "publisher", Username: "publisher", Roles: []string{"publisher"}, CreatedAt: time.Now(), Password: "pub123"},
{ID: "consumer", Username: "consumer", Roles: []string{"consumer"}, CreatedAt: time.Now(), Password: "con123"},
}
return sm.AddUsers(users...)
}
// ----------------------------
// Performance optimizations:
// ----------------------------
// 1. Zero allocations on hot path:
// - Fixed-size arrays instead of slices for handlers/params
// - Stack-allocated param arrays
// - Byte-based trie with O(1) static lookups
// - Pre-allocated pools for everything
//
// 2. Minimal interface overhead:
// - Direct memory operations
// - Embedded handler chains in context
// - Method type enum instead of string comparisons
//
// 3. Optimized data structures:
// - 256-element array for O(1) first-byte lookup
// - Linear search for params (faster than map for < 8 items)
// - Pre-computed route chains stored in trie
//
// 4. Fast path parsing:
// - Single-pass path splitting
// - Zero-allocation string operations
// - Minimal string comparisons
//
// This implementation should now match gofiber's performance by using
// similar zero-allocation techniques and optimized data structures.

187
mq.go
View File

@@ -280,6 +280,12 @@ type Options struct {
BrokerRateLimiter *RateLimiter // new field for broker rate limiting
ConsumerRateLimiter *RateLimiter // new field for consumer rate limiting
consumerTimeout time.Duration // timeout for consumer message processing (0 = no timeout)
adminAddr string // address for admin server
metricsAddr string // address for metrics server
enableSecurity bool // enable security features
enableMonitoring bool // enable monitoring features
username string // username for authentication
password string // password for authentication
}
func (o *Options) SetSyncMode(sync bool) {
@@ -467,15 +473,19 @@ type Broker struct {
listener net.Listener
// Enhanced production features
connectionPool *ConnectionPool
healthChecker *HealthChecker
circuitBreaker *EnhancedCircuitBreaker
metricsCollector *MetricsCollector
messageStore MessageStore
isShutdown int32
shutdown chan struct{}
wg sync.WaitGroup
logger logger.Logger
connectionPool *ConnectionPool
healthChecker *HealthChecker
circuitBreaker *EnhancedCircuitBreaker
metricsCollector *MetricsCollector
messageStore MessageStore
securityManager *SecurityManager
adminServer *AdminServer
metricsServer *MetricsServer
authenticatedConns storage.IMap[string, bool] // authenticated connections
isShutdown int32
shutdown chan struct{}
wg sync.WaitGroup
logger logger.Logger
}
func NewBroker(opts ...Option) *Broker {
@@ -491,13 +501,38 @@ func NewBroker(opts ...Option) *Broker {
opts: options,
// Enhanced production features
connectionPool: NewConnectionPool(1000), // max 1000 connections
healthChecker: NewHealthChecker(),
circuitBreaker: NewEnhancedCircuitBreaker(10, 30*time.Second), // 10 failures, 30s timeout
metricsCollector: NewMetricsCollector(),
messageStore: NewInMemoryMessageStore(),
shutdown: make(chan struct{}),
logger: options.Logger(),
connectionPool: NewConnectionPool(1000), // max 1000 connections
healthChecker: NewHealthChecker(),
circuitBreaker: NewEnhancedCircuitBreaker(10, 30*time.Second), // 10 failures, 30s timeout
metricsCollector: NewMetricsCollector(),
messageStore: NewInMemoryMessageStore(),
authenticatedConns: memory.New[string, bool](),
shutdown: make(chan struct{}),
logger: options.Logger(),
}
if options.enableSecurity {
broker.securityManager = NewSecurityManager()
}
if options.enableMonitoring {
if options.adminAddr != "" {
broker.adminServer = NewAdminServer(broker, options.adminAddr, options.Logger())
}
if options.metricsAddr != "" {
// Need to create MonitoringConfig, use default
config := &MonitoringConfig{
EnableMetrics: true,
MetricsPort: 9090, // default
MetricsPath: "/metrics",
EnableHealthCheck: true,
HealthCheckPort: 8080,
HealthCheckPath: "/health",
HealthCheckInterval: time.Minute,
EnableLogging: true,
LogLevel: "info",
}
broker.metricsServer = NewMetricsServer(broker, config, options.Logger())
}
}
broker.healthChecker.broker = broker
@@ -508,6 +543,18 @@ func (b *Broker) Options() *Options {
return b.opts
}
func (b *Broker) SecurityManager() *SecurityManager {
return b.securityManager
}
// InitializeSecurity initializes default users, roles, and permissions for development/testing
func (b *Broker) InitializeSecurity() error {
if b.securityManager == nil {
return fmt.Errorf("security manager not initialized")
}
return nil
}
func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
consumerID, ok := GetConsumerID(ctx)
if ok && consumerID != "" {
@@ -553,6 +600,10 @@ func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
b.publishers.Del(publisherID)
}
}
// Remove from authenticated connections
connID := conn.RemoteAddr().String()
b.authenticatedConns.Del(connID)
log.Printf("BROKER - Connection closed: address %s", conn.RemoteAddr())
return nil
}
@@ -563,8 +614,29 @@ func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) {
}
}
func (b *Broker) isAuthenticated(connID string) bool {
if b.securityManager == nil {
return true // no security, allow all
}
_, ok := b.authenticatedConns.Get(connID)
return ok
}
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
connID := conn.RemoteAddr().String()
// Check authentication for protected commands
if b.securityManager != nil && (msg.Command == consts.PUBLISH || msg.Command == consts.SUBSCRIBE) {
if !b.isAuthenticated(connID) {
b.logger.Warn("Unauthenticated access attempt", logger.Field{Key: "command", Value: msg.Command.String()}, logger.Field{Key: "conn", Value: connID})
// Send error response
return
}
}
switch msg.Command {
case consts.AUTH:
b.AuthHandler(ctx, conn, msg)
case consts.PUBLISH:
b.PublishHandler(ctx, conn, msg)
case consts.SUBSCRIBE:
@@ -588,6 +660,54 @@ func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Con
}
}
func (b *Broker) AuthHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
connID := conn.RemoteAddr().String()
// Parse auth credentials from payload
var authReq map[string]any
if err := json.Unmarshal(msg.Payload, &authReq); err != nil {
b.logger.Error("Invalid auth request", logger.Field{Key: "error", Value: err.Error()})
return
}
username, _ := authReq["username"].(string)
password, _ := authReq["password"].(string)
// Authenticate
user, err := b.securityManager.Authenticate(ctx, map[string]any{
"username": username,
"password": password,
})
if err != nil {
b.logger.Warn("Authentication failed", logger.Field{Key: "username", Value: username}, logger.Field{Key: "conn", Value: connID})
// Send AUTH_DENY
denyMsg, err := codec.NewMessage(consts.AUTH_DENY, []byte(fmt.Sprintf(`{"error":"%s"}`, err.Error())), "", msg.Headers)
if err != nil {
b.logger.Error("Failed to create AUTH_DENY message", logger.Field{Key: "error", Value: err.Error()})
return
}
if err := b.send(ctx, conn, denyMsg); err != nil {
b.logger.Error("Failed to send AUTH_DENY", logger.Field{Key: "error", Value: err.Error()})
}
return
}
// Mark as authenticated
b.authenticatedConns.Set(connID, true)
// Send AUTH_ACK
ackMsg, err := codec.NewMessage(consts.AUTH_ACK, []byte(`{"status":"authenticated"}`), "", msg.Headers)
if err != nil {
b.logger.Error("Failed to create AUTH_ACK message", logger.Field{Key: "error", Value: err.Error()})
return
}
if err := b.send(ctx, conn, ackMsg); err != nil {
b.logger.Error("Failed to send AUTH_ACK", logger.Field{Key: "error", Value: err.Error()})
}
b.logger.Info("User authenticated", logger.Field{Key: "username", Value: user.Username}, logger.Field{Key: "conn", Value: connID})
}
func (b *Broker) AdjustConsumerWorkers(noOfWorkers int, consumerID ...string) {
b.consumers.ForEach(func(_ string, c *consumer) bool {
return true
@@ -1498,10 +1618,43 @@ func (b *Broker) StartEnhanced(ctx context.Context) error {
b.wg.Add(1)
go b.messageStoreCleanupRoutine()
// Start admin server if enabled
if b.adminServer != nil {
b.wg.Add(1)
go func() {
defer b.wg.Done()
if err := b.adminServer.Start(); err != nil {
b.logger.Error("Failed to start admin server", logger.Field{Key: "error", Value: err.Error()})
}
}()
}
// Start metrics server if enabled
if b.metricsServer != nil {
b.wg.Add(1)
go func() {
defer b.wg.Done()
if err := b.metricsServer.Start(ctx); err != nil {
b.logger.Error("Failed to start metrics server", logger.Field{Key: "error", Value: err.Error()})
}
}()
}
b.logger.Info("Enhanced broker starting with production features enabled")
// Start the enhanced broker with its own implementation
return b.startEnhancedBroker(ctx)
if err := b.startEnhancedBroker(ctx); err != nil {
return err
}
// Wait for shutdown signal
<-b.shutdown
b.logger.Info("Enhanced broker shutting down")
// Wait for all goroutines to finish
b.wg.Wait()
return nil
}
// startEnhancedBroker starts the core broker functionality

View File

@@ -314,3 +314,39 @@ func WithConsumerTimeout(timeout time.Duration) Option {
opts.consumerTimeout = timeout
}
}
func WithAdminAddr(addr string) Option {
return func(opts *Options) {
opts.adminAddr = addr
}
}
func WithMetricsAddr(addr string) Option {
return func(opts *Options) {
opts.metricsAddr = addr
}
}
func WithSecurity(enabled bool) Option {
return func(opts *Options) {
opts.enableSecurity = enabled
}
}
func WithMonitoring(enabled bool) Option {
return func(opts *Options) {
opts.enableMonitoring = enabled
}
}
func WithUsername(username string) Option {
return func(opts *Options) {
opts.username = username
}
}
func WithPassword(password string) Option {
return func(opts *Options) {
opts.password = password
}
}

View File

@@ -18,10 +18,11 @@ import (
)
type Publisher struct {
opts *Options
id string
conn net.Conn
connLock sync.Mutex
opts *Options
id string
conn net.Conn
connLock sync.Mutex
authenticated bool
}
func NewPublisher(id string, opts ...Option) *Publisher {
@@ -60,12 +61,70 @@ func (p *Publisher) ensureConnection(ctx context.Context) error {
return fmt.Errorf("failed to connect to broker after retries: %w", err)
}
// Auth authenticates the publisher with the broker
func (p *Publisher) Auth(ctx context.Context, username, password string) error {
if err := p.ensureConnection(ctx); err != nil {
return err
}
p.connLock.Lock()
conn := p.conn
p.connLock.Unlock()
authPayload := map[string]string{
"username": username,
"password": password,
}
payload, err := json.Marshal(authPayload)
if err != nil {
return err
}
headers := map[string]string{
consts.PublisherKey: p.id,
consts.ContentType: consts.TypeJson,
}
msg, err := codec.NewMessage(consts.AUTH, payload, "", headers)
if err != nil {
return err
}
err = codec.SendMessage(ctx, conn, msg)
if err != nil {
return err
}
// Wait for AUTH_ACK
resp, err := codec.ReadMessage(ctx, conn)
if err != nil {
return err
}
if resp.Command != consts.AUTH_ACK {
return fmt.Errorf("authentication failed: %s", string(resp.Payload))
}
return nil
}
// Publish method that uses the persistent connection.
func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error {
// Ensure connection is established.
if err := p.ensureConnection(ctx); err != nil {
return err
}
// Authenticate if security is enabled and not already authenticated
if p.opts.enableSecurity && !p.authenticated {
if p.opts.username == "" || p.opts.password == "" {
return fmt.Errorf("username and password required for authentication")
}
if err := p.Auth(ctx, p.opts.username, p.opts.password); err != nil {
return fmt.Errorf("authentication failed: %w", err)
}
p.authenticated = true
}
delay := p.opts.initialDelay
for i := 0; i < p.opts.maxRetries; i++ {
// Use the persistent connection.

View File

@@ -10,8 +10,318 @@ import (
"strings"
"sync"
"time"
"github.com/oarkflow/squealx"
)
// Storage interfaces for persistence
type UserStorage interface {
GetUser(username string) (*User, error)
GetPassword(username string) (string, error)
SaveUser(user *User, password string) error
ListUsers() ([]*User, error)
DeleteUser(username string) error
}
type RoleStorage interface {
GetRole(name string) (*Role, error)
SaveRole(role *Role) error
ListRoles() ([]*Role, error)
DeleteRole(name string) error
}
type PermissionStorage interface {
GetPermission(name string) (*Permission, error)
SavePermission(perm *Permission) error
ListPermissions() ([]*Permission, error)
DeletePermission(name string) error
}
// MemoryUserStorage implements UserStorage using in-memory storage
type MemoryUserStorage struct {
users map[string]*User
passwords map[string]string
mu sync.RWMutex
}
func NewMemoryUserStorage() *MemoryUserStorage {
return &MemoryUserStorage{
users: make(map[string]*User),
passwords: make(map[string]string),
}
}
func (mus *MemoryUserStorage) GetUser(username string) (*User, error) {
mus.mu.RLock()
defer mus.mu.RUnlock()
user, exists := mus.users[username]
if !exists {
return nil, fmt.Errorf("user not found")
}
return user, nil
}
func (mus *MemoryUserStorage) SaveUser(user *User, password string) error {
mus.mu.Lock()
defer mus.mu.Unlock()
mus.users[user.Username] = user
mus.passwords[user.Username] = password
return nil
}
func (mus *MemoryUserStorage) ListUsers() ([]*User, error) {
mus.mu.RLock()
defer mus.mu.RUnlock()
users := make([]*User, 0, len(mus.users))
for _, user := range mus.users {
users = append(users, user)
}
return users, nil
}
func (mus *MemoryUserStorage) GetPassword(username string) (string, error) {
mus.mu.RLock()
defer mus.mu.RUnlock()
password, exists := mus.passwords[username]
if !exists {
return "", fmt.Errorf("password not found")
}
return password, nil
}
func (mus *MemoryUserStorage) DeleteUser(username string) error {
mus.mu.Lock()
defer mus.mu.Unlock()
delete(mus.users, username)
delete(mus.passwords, username)
return nil
}
// MemoryRoleStorage implements RoleStorage using in-memory storage
type MemoryRoleStorage struct {
roles map[string]*Role
mu sync.RWMutex
}
func NewMemoryRoleStorage() *MemoryRoleStorage {
return &MemoryRoleStorage{
roles: make(map[string]*Role),
}
}
func (mrs *MemoryRoleStorage) GetRole(name string) (*Role, error) {
mrs.mu.RLock()
defer mrs.mu.RUnlock()
role, exists := mrs.roles[name]
if !exists {
return nil, fmt.Errorf("role not found")
}
return role, nil
}
func (mrs *MemoryRoleStorage) SaveRole(role *Role) error {
mrs.mu.Lock()
defer mrs.mu.Unlock()
mrs.roles[role.Name] = role
return nil
}
func (mrs *MemoryRoleStorage) ListRoles() ([]*Role, error) {
mrs.mu.RLock()
defer mrs.mu.RUnlock()
roles := make([]*Role, 0, len(mrs.roles))
for _, role := range mrs.roles {
roles = append(roles, role)
}
return roles, nil
}
func (mrs *MemoryRoleStorage) DeleteRole(name string) error {
mrs.mu.Lock()
defer mrs.mu.Unlock()
delete(mrs.roles, name)
return nil
}
// MemoryPermissionStorage implements PermissionStorage using in-memory storage
type MemoryPermissionStorage struct {
permissions map[string]*Permission
mu sync.RWMutex
}
func NewMemoryPermissionStorage() *MemoryPermissionStorage {
return &MemoryPermissionStorage{
permissions: make(map[string]*Permission),
}
}
func (mps *MemoryPermissionStorage) GetPermission(name string) (*Permission, error) {
mps.mu.RLock()
defer mps.mu.RUnlock()
perm, exists := mps.permissions[name]
if !exists {
return nil, fmt.Errorf("permission not found")
}
return perm, nil
}
func (mps *MemoryPermissionStorage) SavePermission(perm *Permission) error {
mps.mu.Lock()
defer mps.mu.Unlock()
mps.permissions[perm.Name] = perm
return nil
}
func (mps *MemoryPermissionStorage) ListPermissions() ([]*Permission, error) {
mps.mu.RLock()
defer mps.mu.RUnlock()
perms := make([]*Permission, 0, len(mps.permissions))
for _, perm := range mps.permissions {
perms = append(perms, perm)
}
return perms, nil
}
func (mps *MemoryPermissionStorage) DeletePermission(name string) error {
mps.mu.Lock()
defer mps.mu.Unlock()
delete(mps.permissions, name)
return nil
}
// SQLUserStorage implements UserStorage using SQL database
type SQLUserStorage struct {
db *squealx.DB
}
func NewSQLUserStorage(db *squealx.DB) *SQLUserStorage {
return &SQLUserStorage{db: db}
}
func (sus *SQLUserStorage) GetUser(username string) (*User, error) {
var user User
err := sus.db.Get(&user, "SELECT id, username, roles, permissions, metadata, created_at, last_login_at FROM users WHERE username = $1", username)
if err != nil {
return nil, err
}
return &user, nil
}
func (sus *SQLUserStorage) GetPassword(username string) (string, error) {
var password string
err := sus.db.Get(&password, "SELECT password FROM users WHERE username = $1", username)
if err != nil {
return "", err
}
return password, nil
}
func (sus *SQLUserStorage) SaveUser(user *User, password string) error {
_, err := sus.db.Exec(`
INSERT INTO users (id, username, password, roles, permissions, metadata, created_at, last_login_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (username) DO UPDATE SET
password = EXCLUDED.password,
roles = EXCLUDED.roles,
permissions = EXCLUDED.permissions,
metadata = EXCLUDED.metadata,
last_login_at = EXCLUDED.last_login_at`,
user.ID, user.Username, password, user.Roles, user.Permissions, user.Metadata, user.CreatedAt, user.LastLoginAt)
return err
}
func (sus *SQLUserStorage) ListUsers() ([]*User, error) {
var users []*User
err := sus.db.Select(&users, "SELECT id, username, roles, permissions, metadata, created_at, last_login_at FROM users")
return users, err
}
func (sus *SQLUserStorage) DeleteUser(username string) error {
_, err := sus.db.Exec("DELETE FROM users WHERE username = $1", username)
return err
}
// SQLRoleStorage implements RoleStorage using SQL database
type SQLRoleStorage struct {
db *squealx.DB
}
func NewSQLRoleStorage(db *squealx.DB) *SQLRoleStorage {
return &SQLRoleStorage{db: db}
}
func (srs *SQLRoleStorage) GetRole(name string) (*Role, error) {
var role Role
err := srs.db.Get(&role, "SELECT name, description, permissions, created_at FROM roles WHERE name = $1", name)
if err != nil {
return nil, err
}
return &role, nil
}
func (srs *SQLRoleStorage) SaveRole(role *Role) error {
_, err := srs.db.Exec(`
INSERT INTO roles (name, description, permissions, created_at)
VALUES ($1, $2, $3, $4)
ON CONFLICT (name) DO UPDATE SET
description = EXCLUDED.description,
permissions = EXCLUDED.permissions`,
role.Name, role.Description, role.Permissions, role.CreatedAt)
return err
}
func (srs *SQLRoleStorage) ListRoles() ([]*Role, error) {
var roles []*Role
err := srs.db.Select(&roles, "SELECT name, description, permissions, created_at FROM roles")
return roles, err
}
func (srs *SQLRoleStorage) DeleteRole(name string) error {
_, err := srs.db.Exec("DELETE FROM roles WHERE name = $1", name)
return err
}
// SQLPermissionStorage implements PermissionStorage using SQL database
type SQLPermissionStorage struct {
db *squealx.DB
}
func NewSQLPermissionStorage(db *squealx.DB) *SQLPermissionStorage {
return &SQLPermissionStorage{db: db}
}
func (sps *SQLPermissionStorage) GetPermission(name string) (*Permission, error) {
var perm Permission
err := sps.db.Get(&perm, "SELECT name, resource, action, description, created_at FROM permissions WHERE name = $1", name)
if err != nil {
return nil, err
}
return &perm, nil
}
func (sps *SQLPermissionStorage) SavePermission(perm *Permission) error {
_, err := sps.db.Exec(`
INSERT INTO permissions (name, resource, action, description, created_at)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (name) DO UPDATE SET
resource = EXCLUDED.resource,
action = EXCLUDED.action,
description = EXCLUDED.description`,
perm.Name, perm.Resource, perm.Action, perm.Description, perm.CreatedAt)
return err
}
func (sps *SQLPermissionStorage) ListPermissions() ([]*Permission, error) {
var perms []*Permission
err := sps.db.Select(&perms, "SELECT name, resource, action, description, created_at FROM permissions")
return perms, err
}
func (sps *SQLPermissionStorage) DeletePermission(name string) error {
_, err := sps.db.Exec("DELETE FROM permissions WHERE name = $1", name)
return err
}
// SecurityManager handles authentication, authorization, and security policies
type SecurityManager struct {
authProviders map[string]AuthProvider
@@ -34,6 +344,7 @@ type AuthProvider interface {
type User struct {
ID string `json:"id"`
Username string `json:"username"`
Password string `json:"-"`
Roles []string `json:"roles"`
Permissions []string `json:"permissions"`
Metadata map[string]any `json:"metadata,omitempty"`
@@ -43,9 +354,9 @@ type User struct {
// RoleManager manages user roles and permissions
type RoleManager struct {
roles map[string]*Role
permissions map[string]*Permission
mu sync.RWMutex
roleStorage RoleStorage
permissionStorage PermissionStorage
mu sync.RWMutex
}
// Role represents a user role with associated permissions
@@ -123,93 +434,47 @@ func NewSecurityManager() *SecurityManager {
key := make([]byte, 32)
rand.Read(key)
return &SecurityManager{
// Create memory storages by default
userStorage := NewMemoryUserStorage()
roleStorage := NewMemoryRoleStorage()
permissionStorage := NewMemoryPermissionStorage()
sm := &SecurityManager{
authProviders: make(map[string]AuthProvider),
roleManager: NewRoleManager(),
roleManager: NewRoleManager(roleStorage, permissionStorage),
rateLimiter: NewSecurityRateLimiter(5, time.Minute*15), // 5 attempts per 15 minutes
auditLogger: NewAuditLogger(10000),
sessionManager: NewSessionManager(time.Hour * 24), // 24 hour sessions
encryptionKey: key,
}
// Add default basic auth provider
basicProvider := NewBasicAuthProvider(userStorage)
sm.AddAuthProvider(basicProvider)
return sm
}
// NewRoleManager creates a new role manager
func NewRoleManager() *RoleManager {
rm := &RoleManager{
roles: make(map[string]*Role),
permissions: make(map[string]*Permission),
func NewRoleManager(roleStorage RoleStorage, permissionStorage PermissionStorage) *RoleManager {
return &RoleManager{
roleStorage: roleStorage,
permissionStorage: permissionStorage,
}
// Initialize default permissions
rm.AddPermission(&Permission{
Name: "task.publish",
Resource: "task",
Action: "publish",
Description: "Publish tasks to queues",
CreatedAt: time.Now(),
})
rm.AddPermission(&Permission{
Name: "task.consume",
Resource: "task",
Action: "consume",
Description: "Consume tasks from queues",
CreatedAt: time.Now(),
})
rm.AddPermission(&Permission{
Name: "queue.manage",
Resource: "queue",
Action: "manage",
Description: "Manage queues",
CreatedAt: time.Now(),
})
rm.AddPermission(&Permission{
Name: "admin.system",
Resource: "system",
Action: "admin",
Description: "System administration",
CreatedAt: time.Now(),
})
// Initialize default roles
rm.AddRole(&Role{
Name: "publisher",
Description: "Can publish tasks",
Permissions: []string{"task.publish"},
CreatedAt: time.Now(),
})
rm.AddRole(&Role{
Name: "consumer",
Description: "Can consume tasks",
Permissions: []string{"task.consume"},
CreatedAt: time.Now(),
})
rm.AddRole(&Role{
Name: "admin",
Description: "Full system access",
Permissions: []string{"task.publish", "task.consume", "queue.manage", "admin.system"},
CreatedAt: time.Now(),
})
return rm
}
// AddPermission adds a permission to the role manager
func (rm *RoleManager) AddPermission(perm *Permission) {
func (rm *RoleManager) AddPermission(perm *Permission) error {
rm.mu.Lock()
defer rm.mu.Unlock()
rm.permissions[perm.Name] = perm
return rm.permissionStorage.SavePermission(perm)
}
// AddRole adds a role to the role manager
func (rm *RoleManager) AddRole(role *Role) {
func (rm *RoleManager) AddRole(role *Role) error {
rm.mu.Lock()
defer rm.mu.Unlock()
rm.roles[role.Name] = role
return rm.roleStorage.SaveRole(role)
}
// HasPermission checks if a user has a specific permission
@@ -218,11 +483,13 @@ func (rm *RoleManager) HasPermission(user *User, permission string) bool {
defer rm.mu.RUnlock()
for _, roleName := range user.Roles {
if role, exists := rm.roles[roleName]; exists {
for _, perm := range role.Permissions {
if perm == permission {
return true
}
role, err := rm.roleStorage.GetRole(roleName)
if err != nil {
continue
}
for _, perm := range role.Permissions {
if perm == permission {
return true
}
}
}
@@ -236,10 +503,12 @@ func (rm *RoleManager) GetUserPermissions(user *User) []string {
permissions := make(map[string]bool)
for _, roleName := range user.Roles {
if role, exists := rm.roles[roleName]; exists {
for _, perm := range role.Permissions {
permissions[perm] = true
}
role, err := rm.roleStorage.GetRole(roleName)
if err != nil {
continue
}
for _, perm := range role.Permissions {
permissions[perm] = true
}
}
@@ -469,6 +738,60 @@ func (sm *SecurityManager) Authenticate(ctx context.Context, credentials map[str
return nil, fmt.Errorf("authentication failed: %w", lastErr)
}
func (sm *SecurityManager) AddPermission(perm *Permission) error {
if perm == nil || (perm.Name == "" && (perm.Resource == "" || perm.Action == "")) {
return fmt.Errorf("invalid permission")
}
return sm.roleManager.AddPermission(perm)
}
func (sm *SecurityManager) AddRole(role *Role) error {
if role == nil || role.Name == "" {
return fmt.Errorf("invalid role")
}
return sm.roleManager.AddRole(role)
}
func (sm *SecurityManager) AddUsers(users ...*User) error {
for _, user := range users {
if user == nil || user.Username == "" || user.Password == "" {
return fmt.Errorf("invalid user")
}
if err := sm.AddUser(user); err != nil {
return err
}
}
return nil
}
func (sm *SecurityManager) AddRoles(roles ...*Role) error {
for _, role := range roles {
if err := sm.AddRole(role); err != nil {
return err
}
}
return nil
}
func (sm *SecurityManager) AddPermissions(perms ...*Permission) error {
for _, perm := range perms {
if err := sm.AddPermission(perm); err != nil {
return err
}
}
return nil
}
// AddUser adds a user to the system
func (sm *SecurityManager) AddUser(user *User) error {
for _, provider := range sm.authProviders {
if bap, ok := provider.(*BasicAuthProvider); ok {
return bap.AddUser(user, user.Password)
}
}
return fmt.Errorf("no suitable auth provider found")
}
// Authorize checks if a user is authorized for an action
func (sm *SecurityManager) Authorize(user *User, resource, action string) error {
permission := fmt.Sprintf("%s.%s", resource, action)
@@ -551,13 +874,12 @@ func (sm *SecurityManager) Decrypt(data []byte) ([]byte, error) {
// BasicAuthProvider implements basic username/password authentication
type BasicAuthProvider struct {
users map[string]*User
mu sync.RWMutex
userStorage UserStorage
}
func NewBasicAuthProvider() *BasicAuthProvider {
func NewBasicAuthProvider(userStorage UserStorage) *BasicAuthProvider {
return &BasicAuthProvider{
users: make(map[string]*User),
userStorage: userStorage,
}
}
@@ -576,16 +898,13 @@ func (bap *BasicAuthProvider) Authenticate(ctx context.Context, credentials map[
return nil, fmt.Errorf("password required")
}
bap.mu.RLock()
user, exists := bap.users[username]
bap.mu.RUnlock()
if !exists {
user, err := bap.userStorage.GetUser(username)
if err != nil {
return nil, fmt.Errorf("user not found")
}
// In production, compare hashed passwords
if password != "password" { // Placeholder
storedPassword, err := bap.userStorage.GetPassword(username)
if err != nil || storedPassword != password {
return nil, fmt.Errorf("invalid password")
}
@@ -611,14 +930,9 @@ func (bap *BasicAuthProvider) ValidateToken(token string) (*User, error) {
}
func (bap *BasicAuthProvider) AddUser(user *User, password string) error {
bap.mu.Lock()
defer bap.mu.Unlock()
// In production, hash the password
user.CreatedAt = time.Now()
bap.users[user.Username] = user
return nil
return bap.userStorage.SaveUser(user, password)
}
// generateID generates a random ID