mirror of
https://github.com/oarkflow/mq.git
synced 2025-09-26 20:11:16 +08:00
update
This commit is contained in:
@@ -445,7 +445,7 @@ func validateCommand(cmd consts.CMD) error {
|
||||
}
|
||||
|
||||
func validateQueue(queue string) error {
|
||||
if len(queue) == 0 || len(queue) > MaxQueueLength {
|
||||
if len(queue) > MaxQueueLength {
|
||||
return ErrInvalidQueue
|
||||
}
|
||||
return nil
|
||||
|
@@ -6,6 +6,9 @@ func (c CMD) IsValid() bool { return c >= PING && c <= CONSUMER_STOP }
|
||||
|
||||
const (
|
||||
PING CMD = iota + 1
|
||||
AUTH
|
||||
AUTH_ACK
|
||||
AUTH_DENY
|
||||
SUBSCRIBE
|
||||
SUBSCRIBE_ACK
|
||||
|
||||
@@ -42,6 +45,12 @@ func (c CMD) String() string {
|
||||
switch c {
|
||||
case PING:
|
||||
return "PING"
|
||||
case AUTH:
|
||||
return "AUTH"
|
||||
case AUTH_ACK:
|
||||
return "AUTH_ACK"
|
||||
case AUTH_DENY:
|
||||
return "AUTH_DENY"
|
||||
case SUBSCRIBE:
|
||||
return "SUBSCRIBE"
|
||||
case SUBSCRIBE_ACK:
|
||||
|
44
consumer.go
44
consumer.go
@@ -162,6 +162,40 @@ func (c *Consumer) subscribe(ctx context.Context, queue string) error {
|
||||
return c.waitForAck(ctx, c.conn)
|
||||
}
|
||||
|
||||
// Auth authenticates the consumer with the broker
|
||||
func (c *Consumer) Auth(ctx context.Context, username, password string) error {
|
||||
authPayload := map[string]string{
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
payload, err := json.Marshal(authPayload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
headers := HeadersWithConsumerID(ctx, c.id)
|
||||
msg, err := codec.NewMessage(consts.AUTH, payload, "", headers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating auth message: %v", err)
|
||||
}
|
||||
|
||||
if err := c.send(ctx, c.conn, msg); err != nil {
|
||||
return fmt.Errorf("error sending auth: %v", err)
|
||||
}
|
||||
|
||||
// Wait for AUTH_ACK
|
||||
resp, err := c.receive(ctx, c.conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error receiving auth response: %v", err)
|
||||
}
|
||||
|
||||
if resp.Command != consts.AUTH_ACK {
|
||||
return fmt.Errorf("authentication failed: %s", string(resp.Payload))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Consumer) OnClose(_ context.Context, _ net.Conn) error {
|
||||
fmt.Println("Consumer closed")
|
||||
return nil
|
||||
@@ -559,6 +593,16 @@ func (c *Consumer) Consume(ctx context.Context) error {
|
||||
return fmt.Errorf("initial connection failed: %w", err)
|
||||
}
|
||||
|
||||
// Authenticate if security is enabled
|
||||
if c.opts.enableSecurity {
|
||||
if c.opts.username == "" || c.opts.password == "" {
|
||||
return fmt.Errorf("username and password required for authentication")
|
||||
}
|
||||
if err := c.Auth(ctx, c.opts.username, c.opts.password); err != nil {
|
||||
return fmt.Errorf("authentication failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize pool
|
||||
c.pool = NewPool(
|
||||
c.opts.numOfWorkers,
|
||||
|
21
examples/consumer.go
Normal file
21
examples/consumer.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
"github.com/oarkflow/mq/examples/tasks"
|
||||
)
|
||||
|
||||
func main() {
|
||||
n := &tasks.Node6{}
|
||||
consumer1 := mq.NewConsumer("F", "queue1", n.ProcessTask,
|
||||
mq.WithBrokerURL(":8081"),
|
||||
mq.WithHTTPApi(true),
|
||||
mq.WithWorkerPool(100, 4, 50000),
|
||||
mq.WithSecurity(true),
|
||||
mq.WithUsername("consumer"),
|
||||
mq.WithPassword("con123"),
|
||||
)
|
||||
consumer1.Consume(context.Background())
|
||||
}
|
29
examples/publisher.go
Normal file
29
examples/publisher.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
)
|
||||
|
||||
func main() {
|
||||
payload := []byte(`{"phone": "+123456789", "email": "abc.xyz@gmail.com", "age": 12}`)
|
||||
task := mq.Task{
|
||||
Payload: payload,
|
||||
}
|
||||
publisher := mq.NewPublisher("publish-1",
|
||||
mq.WithBrokerURL(":8081"),
|
||||
mq.WithSecurity(true),
|
||||
mq.WithUsername("publisher"),
|
||||
mq.WithPassword("pub123"),
|
||||
)
|
||||
for i := 0; i < 2; i++ {
|
||||
// publisher := mq.NewPublisher("publish-1", mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"))
|
||||
err := publisher.Publish(context.Background(), task, "queue1")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
fmt.Println("Async task published successfully")
|
||||
}
|
@@ -1,604 +1,57 @@
|
||||
// fast_http_router.go
|
||||
// Ultra-high performance HTTP router in Go matching gofiber speed
|
||||
// Key optimizations:
|
||||
// - Zero allocations on hot path (no slice/map allocations per request)
|
||||
// - Byte-based routing for maximum speed
|
||||
// - Pre-allocated pools for everything
|
||||
// - Minimal interface overhead
|
||||
// - Direct memory operations where possible
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/mq"
|
||||
"github.com/oarkflow/mq/examples/tasks"
|
||||
)
|
||||
|
||||
// ----------------------------
|
||||
// Public Interfaces (minimal overhead)
|
||||
// ----------------------------
|
||||
|
||||
type HandlerFunc func(*Ctx) error
|
||||
|
||||
type Engine interface {
|
||||
http.Handler
|
||||
Group(prefix string, m ...HandlerFunc) RouteGroup
|
||||
Use(m ...HandlerFunc)
|
||||
GET(path string, h HandlerFunc)
|
||||
POST(path string, h HandlerFunc)
|
||||
PUT(path string, h HandlerFunc)
|
||||
DELETE(path string, h HandlerFunc)
|
||||
Static(prefix, root string)
|
||||
ListenAndServe(addr string) error
|
||||
Shutdown(ctx context.Context) error
|
||||
func main() {
|
||||
b := mq.NewBroker(
|
||||
mq.WithCallback(tasks.Callback),
|
||||
mq.WithBrokerURL(":8081"),
|
||||
mq.WithSecurity(true),
|
||||
// mq.WithMonitoring(true),
|
||||
// mq.WithAdminAddr(":8080"),
|
||||
// mq.WithMetricsAddr(":9090"),
|
||||
)
|
||||
InitializeDefaults(b.SecurityManager())
|
||||
// b := mq.NewBroker(mq.WithCallback(tasks.Callback), mq.WithTLS(true, "./certs/server.crt", "./certs/server.key"), mq.WithCAPath("./certs/ca.cert"))
|
||||
if err := b.InitializeSecurity(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
b.NewQueue("queue1")
|
||||
b.NewQueue("queue2")
|
||||
b.StartEnhanced(context.Background())
|
||||
}
|
||||
|
||||
type RouteGroup interface {
|
||||
Use(m ...HandlerFunc)
|
||||
GET(path string, h HandlerFunc)
|
||||
POST(path string, h HandlerFunc)
|
||||
PUT(path string, h HandlerFunc)
|
||||
DELETE(path string, h HandlerFunc)
|
||||
}
|
||||
|
||||
// ----------------------------
|
||||
// Ultra-fast param extraction
|
||||
// ----------------------------
|
||||
|
||||
type Param struct {
|
||||
Key string
|
||||
Value string
|
||||
}
|
||||
|
||||
// Pre-allocated param slices to avoid any allocations
|
||||
var paramPool = sync.Pool{
|
||||
New: func() any {
|
||||
return make([]Param, 0, 16)
|
||||
},
|
||||
}
|
||||
|
||||
// ----------------------------
|
||||
// Context with zero allocations
|
||||
// ----------------------------
|
||||
|
||||
type Ctx struct {
|
||||
W http.ResponseWriter
|
||||
Req *http.Request
|
||||
params []Param
|
||||
index int8
|
||||
plen int8
|
||||
|
||||
// Embedded handler chain (no slice allocation)
|
||||
handlers [16]HandlerFunc // fixed size, 99% of routes have < 16 handlers
|
||||
hlen int8
|
||||
|
||||
status int
|
||||
engine *engine
|
||||
}
|
||||
|
||||
var ctxPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &Ctx{}
|
||||
},
|
||||
}
|
||||
|
||||
func (c *Ctx) reset() {
|
||||
c.W = nil
|
||||
c.Req = nil
|
||||
if c.params != nil {
|
||||
paramPool.Put(c.params[:0])
|
||||
c.params = nil
|
||||
// InitializeDefaults adds default permissions, roles, and users for development/testing
|
||||
func InitializeDefaults(sm *mq.SecurityManager) error {
|
||||
permissions := []*mq.Permission{
|
||||
{Name: "task.publish", Resource: "task", Action: "publish", Description: "Publish tasks to queues", CreatedAt: time.Now()},
|
||||
{Name: "task.consume", Resource: "task", Action: "consume", Description: "Consume tasks from queues", CreatedAt: time.Now()},
|
||||
{Name: "queue.manage", Resource: "queue", Action: "manage", Description: "Manage queues", CreatedAt: time.Now()},
|
||||
{Name: "admin.system", Resource: "system", Action: "admin", Description: "System administration", CreatedAt: time.Now()},
|
||||
}
|
||||
c.index = 0
|
||||
c.plen = 0
|
||||
c.hlen = 0
|
||||
c.status = 0
|
||||
c.engine = nil
|
||||
}
|
||||
|
||||
// Ultra-fast param lookup (linear search is faster than map for < 8 params)
|
||||
func (c *Ctx) Param(key string) string {
|
||||
for i := int8(0); i < c.plen; i++ {
|
||||
if c.params[i].Key == key {
|
||||
return c.params[i].Value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *Ctx) addParam(key, value string) {
|
||||
if c.params == nil {
|
||||
c.params = paramPool.Get().([]Param)
|
||||
}
|
||||
if c.plen < 16 { // max 16 params
|
||||
c.params = append(c.params, Param{Key: key, Value: value})
|
||||
c.plen++
|
||||
}
|
||||
}
|
||||
|
||||
// Zero-allocation header operations
|
||||
func (c *Ctx) Set(key, val string) {
|
||||
if c.W != nil {
|
||||
c.W.Header().Set(key, val)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Ctx) Get(key string) string {
|
||||
if c.Req != nil {
|
||||
return c.Req.Header.Get(key)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Ultra-fast response methods
|
||||
func (c *Ctx) SendString(s string) error {
|
||||
if c.status != 0 {
|
||||
c.W.WriteHeader(c.status)
|
||||
}
|
||||
_, err := io.WriteString(c.W, s)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Ctx) JSON(v any) error {
|
||||
c.Set("Content-Type", "application/json")
|
||||
if c.status != 0 {
|
||||
c.W.WriteHeader(c.status)
|
||||
}
|
||||
return json.NewEncoder(c.W).Encode(v)
|
||||
}
|
||||
|
||||
func (c *Ctx) Status(code int) { c.status = code }
|
||||
|
||||
func (c *Ctx) Next() error {
|
||||
for c.index < c.hlen {
|
||||
h := c.handlers[c.index]
|
||||
c.index++
|
||||
if err := h(c); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ----------------------------
|
||||
// Ultra-fast byte-based router
|
||||
// ----------------------------
|
||||
|
||||
type methodType uint8
|
||||
|
||||
const (
|
||||
methodGet methodType = iota
|
||||
methodPost
|
||||
methodPut
|
||||
methodDelete
|
||||
methodOptions
|
||||
methodHead
|
||||
methodPatch
|
||||
)
|
||||
|
||||
var methodMap = map[string]methodType{
|
||||
"GET": methodGet,
|
||||
"POST": methodPost,
|
||||
"PUT": methodPut,
|
||||
"DELETE": methodDelete,
|
||||
"OPTIONS": methodOptions,
|
||||
"HEAD": methodHead,
|
||||
"PATCH": methodPatch,
|
||||
}
|
||||
|
||||
// Route info with pre-computed handler chain
|
||||
type route struct {
|
||||
handlers [16]HandlerFunc
|
||||
hlen int8
|
||||
}
|
||||
|
||||
// Ultra-fast trie node
|
||||
type node struct {
|
||||
// Static children - direct byte lookup for first character
|
||||
static [256]*node
|
||||
|
||||
// Dynamic children
|
||||
param *node
|
||||
wildcard *node
|
||||
|
||||
// Route data
|
||||
routes [8]*route // index by method type
|
||||
|
||||
// Node metadata
|
||||
paramName string
|
||||
isEnd bool
|
||||
}
|
||||
|
||||
// Path parsing with zero allocations
|
||||
func splitPathFast(path string) []string {
|
||||
if path == "/" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count segments first
|
||||
count := 0
|
||||
start := 1 // skip leading /
|
||||
for i := start; i < len(path); i++ {
|
||||
if path[i] == '/' {
|
||||
count++
|
||||
}
|
||||
}
|
||||
count++ // last segment
|
||||
|
||||
// Pre-allocate exact size
|
||||
segments := make([]string, 0, count)
|
||||
start = 1
|
||||
for i := 1; i <= len(path); i++ {
|
||||
if i == len(path) || path[i] == '/' {
|
||||
if i > start {
|
||||
segments = append(segments, path[start:i])
|
||||
}
|
||||
start = i + 1
|
||||
}
|
||||
}
|
||||
return segments
|
||||
}
|
||||
|
||||
// Add route with minimal allocations
|
||||
func (n *node) addRoute(method methodType, segments []string, handlers []HandlerFunc) {
|
||||
curr := n
|
||||
|
||||
for _, seg := range segments {
|
||||
if len(seg) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if seg[0] == ':' {
|
||||
// Parameter route
|
||||
if curr.param == nil {
|
||||
curr.param = &node{paramName: seg[1:]}
|
||||
}
|
||||
curr = curr.param
|
||||
} else if seg[0] == '*' {
|
||||
// Wildcard route
|
||||
if curr.wildcard == nil {
|
||||
curr.wildcard = &node{paramName: seg[1:]}
|
||||
}
|
||||
curr = curr.wildcard
|
||||
break // wildcard consumes rest
|
||||
} else {
|
||||
// Static route - use first byte for O(1) lookup
|
||||
firstByte := seg[0]
|
||||
if curr.static[firstByte] == nil {
|
||||
curr.static[firstByte] = &node{}
|
||||
}
|
||||
curr = curr.static[firstByte]
|
||||
}
|
||||
}
|
||||
|
||||
curr.isEnd = true
|
||||
|
||||
// Store pre-computed handler chain
|
||||
if curr.routes[method] == nil {
|
||||
curr.routes[method] = &route{}
|
||||
}
|
||||
|
||||
r := curr.routes[method]
|
||||
r.hlen = 0
|
||||
for i, h := range handlers {
|
||||
if i >= 16 {
|
||||
break // max 16 handlers
|
||||
}
|
||||
r.handlers[i] = h
|
||||
r.hlen++
|
||||
}
|
||||
}
|
||||
|
||||
// Ultra-fast route matching
|
||||
func (n *node) match(segments []string, params []Param, plen *int8) (*route, methodType, bool) {
|
||||
curr := n
|
||||
|
||||
for i, seg := range segments {
|
||||
if len(seg) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Try static first (O(1) lookup)
|
||||
firstByte := seg[0]
|
||||
if next := curr.static[firstByte]; next != nil {
|
||||
curr = next
|
||||
continue
|
||||
}
|
||||
|
||||
// Try parameter
|
||||
if curr.param != nil {
|
||||
if *plen < 16 {
|
||||
params[*plen] = Param{Key: curr.param.paramName, Value: seg}
|
||||
(*plen)++
|
||||
}
|
||||
curr = curr.param
|
||||
continue
|
||||
}
|
||||
|
||||
// Try wildcard
|
||||
if curr.wildcard != nil {
|
||||
if *plen < 16 {
|
||||
// Wildcard captures remaining path
|
||||
remaining := strings.Join(segments[i:], "/")
|
||||
params[*plen] = Param{Key: curr.wildcard.paramName, Value: remaining}
|
||||
(*plen)++
|
||||
}
|
||||
curr = curr.wildcard
|
||||
break
|
||||
}
|
||||
|
||||
return nil, 0, false
|
||||
}
|
||||
|
||||
if !curr.isEnd {
|
||||
return nil, 0, false
|
||||
}
|
||||
|
||||
// Find method (most common methods first)
|
||||
if r := curr.routes[methodGet]; r != nil {
|
||||
return r, methodGet, true
|
||||
}
|
||||
if r := curr.routes[methodPost]; r != nil {
|
||||
return r, methodPost, true
|
||||
}
|
||||
if r := curr.routes[methodPut]; r != nil {
|
||||
return r, methodPut, true
|
||||
}
|
||||
if r := curr.routes[methodDelete]; r != nil {
|
||||
return r, methodDelete, true
|
||||
}
|
||||
|
||||
return nil, 0, false
|
||||
}
|
||||
|
||||
// ----------------------------
|
||||
// Engine implementation
|
||||
// ----------------------------
|
||||
|
||||
type engine struct {
|
||||
tree *node
|
||||
middleware []HandlerFunc
|
||||
servers []*http.Server
|
||||
shutdown int32
|
||||
}
|
||||
|
||||
func New() Engine {
|
||||
return &engine{
|
||||
tree: &node{},
|
||||
}
|
||||
}
|
||||
|
||||
// Ultra-fast request handling
|
||||
func (e *engine) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if atomic.LoadInt32(&e.shutdown) == 1 {
|
||||
w.WriteHeader(503)
|
||||
return
|
||||
}
|
||||
|
||||
// Get context from pool
|
||||
c := ctxPool.Get().(*Ctx)
|
||||
c.reset()
|
||||
c.W = w
|
||||
c.Req = r
|
||||
c.engine = e
|
||||
|
||||
// Parse path once
|
||||
segments := splitPathFast(r.URL.Path)
|
||||
|
||||
// Pre-allocated param array (on stack)
|
||||
var paramArray [16]Param
|
||||
var plen int8
|
||||
|
||||
// Match route
|
||||
route, _, found := e.tree.match(segments, paramArray[:], &plen)
|
||||
|
||||
if !found {
|
||||
w.WriteHeader(404)
|
||||
w.Write([]byte("404"))
|
||||
ctxPool.Put(c)
|
||||
return
|
||||
}
|
||||
|
||||
// Set params (no allocation)
|
||||
if plen > 0 {
|
||||
c.params = paramPool.Get().([]Param)
|
||||
for i := int8(0); i < plen; i++ {
|
||||
c.params = append(c.params, paramArray[i])
|
||||
}
|
||||
c.plen = plen
|
||||
}
|
||||
|
||||
// Copy handlers (no allocation - fixed array)
|
||||
copy(c.handlers[:], route.handlers[:route.hlen])
|
||||
c.hlen = route.hlen
|
||||
|
||||
// Execute
|
||||
if err := c.Next(); err != nil {
|
||||
w.WriteHeader(500)
|
||||
}
|
||||
|
||||
ctxPool.Put(c)
|
||||
}
|
||||
|
||||
func (e *engine) Use(m ...HandlerFunc) {
|
||||
e.middleware = append(e.middleware, m...)
|
||||
}
|
||||
|
||||
func (e *engine) addRoute(method, path string, groupMiddleware []HandlerFunc, h HandlerFunc) {
|
||||
mt, ok := methodMap[method]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
segments := splitPathFast(path)
|
||||
|
||||
// Build handler chain: global + group + route
|
||||
totalLen := len(e.middleware) + len(groupMiddleware) + 1
|
||||
if totalLen > 16 {
|
||||
totalLen = 16 // max handlers
|
||||
}
|
||||
|
||||
handlers := make([]HandlerFunc, 0, totalLen)
|
||||
handlers = append(handlers, e.middleware...)
|
||||
handlers = append(handlers, groupMiddleware...)
|
||||
handlers = append(handlers, h)
|
||||
|
||||
e.tree.addRoute(mt, segments, handlers)
|
||||
}
|
||||
|
||||
func (e *engine) GET(path string, h HandlerFunc) { e.addRoute("GET", path, nil, h) }
|
||||
func (e *engine) POST(path string, h HandlerFunc) { e.addRoute("POST", path, nil, h) }
|
||||
func (e *engine) PUT(path string, h HandlerFunc) { e.addRoute("PUT", path, nil, h) }
|
||||
func (e *engine) DELETE(path string, h HandlerFunc) { e.addRoute("DELETE", path, nil, h) }
|
||||
|
||||
// RouteGroup implementation
|
||||
type routeGroup struct {
|
||||
prefix string
|
||||
engine *engine
|
||||
middleware []HandlerFunc
|
||||
}
|
||||
|
||||
func (e *engine) Group(prefix string, m ...HandlerFunc) RouteGroup {
|
||||
return &routeGroup{
|
||||
prefix: prefix,
|
||||
engine: e,
|
||||
middleware: m,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *routeGroup) Use(m ...HandlerFunc) { g.middleware = append(g.middleware, m...) }
|
||||
|
||||
func (g *routeGroup) add(method, path string, h HandlerFunc) {
|
||||
fullPath := g.prefix + path
|
||||
g.engine.addRoute(method, fullPath, g.middleware, h)
|
||||
}
|
||||
|
||||
func (g *routeGroup) GET(path string, h HandlerFunc) { g.add("GET", path, h) }
|
||||
func (g *routeGroup) POST(path string, h HandlerFunc) { g.add("POST", path, h) }
|
||||
func (g *routeGroup) PUT(path string, h HandlerFunc) { g.add("PUT", path, h) }
|
||||
func (g *routeGroup) DELETE(path string, h HandlerFunc) { g.add("DELETE", path, h) }
|
||||
|
||||
// Ultra-fast static file serving
|
||||
func (e *engine) Static(prefix, root string) {
|
||||
if !strings.HasSuffix(prefix, "/") {
|
||||
prefix += "/"
|
||||
}
|
||||
e.GET(strings.TrimSuffix(prefix, "/"), func(c *Ctx) error {
|
||||
path := root + "/"
|
||||
http.ServeFile(c.W, c.Req, path)
|
||||
return nil
|
||||
})
|
||||
e.GET(prefix+"*", func(c *Ctx) error {
|
||||
filepath := c.Param("")
|
||||
if filepath == "" {
|
||||
filepath = "/"
|
||||
}
|
||||
path := root + "/" + filepath
|
||||
http.ServeFile(c.W, c.Req, path)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (e *engine) ListenAndServe(addr string) error {
|
||||
srv := &http.Server{Addr: addr, Handler: e}
|
||||
e.servers = append(e.servers, srv)
|
||||
return srv.ListenAndServe()
|
||||
}
|
||||
|
||||
func (e *engine) Shutdown(ctx context.Context) error {
|
||||
atomic.StoreInt32(&e.shutdown, 1)
|
||||
for _, srv := range e.servers {
|
||||
srv.Shutdown(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ----------------------------
|
||||
// Middleware
|
||||
// ----------------------------
|
||||
|
||||
func Recover() HandlerFunc {
|
||||
return func(c *Ctx) error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Printf("panic: %v", r)
|
||||
c.Status(500)
|
||||
c.SendString("Internal Server Error")
|
||||
}
|
||||
}()
|
||||
return c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func Logger() HandlerFunc {
|
||||
return func(c *Ctx) error {
|
||||
start := time.Now()
|
||||
err := c.Next()
|
||||
log.Printf("%s %s %v", c.Req.Method, c.Req.URL.Path, time.Since(start))
|
||||
err := sm.AddPermissions(permissions...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------
|
||||
// Example
|
||||
// ----------------------------
|
||||
|
||||
func mai3n() {
|
||||
app := New()
|
||||
app.Use(Recover())
|
||||
|
||||
app.GET("/", func(c *Ctx) error {
|
||||
return c.SendString("Hello World!")
|
||||
})
|
||||
|
||||
app.GET("/user/:id", func(c *Ctx) error {
|
||||
return c.SendString("User: " + c.Param("id"))
|
||||
})
|
||||
|
||||
api := app.Group("/api")
|
||||
api.GET("/ping", func(c *Ctx) error {
|
||||
return c.JSON(map[string]any{"message": "pong"})
|
||||
})
|
||||
|
||||
app.Static("/static", "public")
|
||||
|
||||
fmt.Println("Server starting on :8080")
|
||||
if err := app.ListenAndServe(":8080"); err != nil {
|
||||
log.Fatal(err)
|
||||
roles := []*mq.Role{
|
||||
{Name: "publisher", Description: "Can publish tasks", Permissions: []string{"task.publish"}, CreatedAt: time.Now()},
|
||||
{Name: "consumer", Description: "Can consume tasks", Permissions: []string{"task.consume"}, CreatedAt: time.Now()},
|
||||
{Name: "admin", Description: "Full system access", Permissions: []string{"task.publish", "task.consume", "queue.manage", "admin.system"}, CreatedAt: time.Now()},
|
||||
}
|
||||
err = sm.AddRoles(roles...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
users := []*mq.User{
|
||||
{ID: "admin", Username: "admin", Roles: []string{"admin"}, CreatedAt: time.Now(), Password: "admin123"},
|
||||
{ID: "publisher", Username: "publisher", Roles: []string{"publisher"}, CreatedAt: time.Now(), Password: "pub123"},
|
||||
{ID: "consumer", Username: "consumer", Roles: []string{"consumer"}, CreatedAt: time.Now(), Password: "con123"},
|
||||
}
|
||||
return sm.AddUsers(users...)
|
||||
}
|
||||
|
||||
// ----------------------------
|
||||
// Performance optimizations:
|
||||
// ----------------------------
|
||||
// 1. Zero allocations on hot path:
|
||||
// - Fixed-size arrays instead of slices for handlers/params
|
||||
// - Stack-allocated param arrays
|
||||
// - Byte-based trie with O(1) static lookups
|
||||
// - Pre-allocated pools for everything
|
||||
//
|
||||
// 2. Minimal interface overhead:
|
||||
// - Direct memory operations
|
||||
// - Embedded handler chains in context
|
||||
// - Method type enum instead of string comparisons
|
||||
//
|
||||
// 3. Optimized data structures:
|
||||
// - 256-element array for O(1) first-byte lookup
|
||||
// - Linear search for params (faster than map for < 8 items)
|
||||
// - Pre-computed route chains stored in trie
|
||||
//
|
||||
// 4. Fast path parsing:
|
||||
// - Single-pass path splitting
|
||||
// - Zero-allocation string operations
|
||||
// - Minimal string comparisons
|
||||
//
|
||||
// This implementation should now match gofiber's performance by using
|
||||
// similar zero-allocation techniques and optimized data structures.
|
||||
|
187
mq.go
187
mq.go
@@ -280,6 +280,12 @@ type Options struct {
|
||||
BrokerRateLimiter *RateLimiter // new field for broker rate limiting
|
||||
ConsumerRateLimiter *RateLimiter // new field for consumer rate limiting
|
||||
consumerTimeout time.Duration // timeout for consumer message processing (0 = no timeout)
|
||||
adminAddr string // address for admin server
|
||||
metricsAddr string // address for metrics server
|
||||
enableSecurity bool // enable security features
|
||||
enableMonitoring bool // enable monitoring features
|
||||
username string // username for authentication
|
||||
password string // password for authentication
|
||||
}
|
||||
|
||||
func (o *Options) SetSyncMode(sync bool) {
|
||||
@@ -467,15 +473,19 @@ type Broker struct {
|
||||
listener net.Listener
|
||||
|
||||
// Enhanced production features
|
||||
connectionPool *ConnectionPool
|
||||
healthChecker *HealthChecker
|
||||
circuitBreaker *EnhancedCircuitBreaker
|
||||
metricsCollector *MetricsCollector
|
||||
messageStore MessageStore
|
||||
isShutdown int32
|
||||
shutdown chan struct{}
|
||||
wg sync.WaitGroup
|
||||
logger logger.Logger
|
||||
connectionPool *ConnectionPool
|
||||
healthChecker *HealthChecker
|
||||
circuitBreaker *EnhancedCircuitBreaker
|
||||
metricsCollector *MetricsCollector
|
||||
messageStore MessageStore
|
||||
securityManager *SecurityManager
|
||||
adminServer *AdminServer
|
||||
metricsServer *MetricsServer
|
||||
authenticatedConns storage.IMap[string, bool] // authenticated connections
|
||||
isShutdown int32
|
||||
shutdown chan struct{}
|
||||
wg sync.WaitGroup
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewBroker(opts ...Option) *Broker {
|
||||
@@ -491,13 +501,38 @@ func NewBroker(opts ...Option) *Broker {
|
||||
opts: options,
|
||||
|
||||
// Enhanced production features
|
||||
connectionPool: NewConnectionPool(1000), // max 1000 connections
|
||||
healthChecker: NewHealthChecker(),
|
||||
circuitBreaker: NewEnhancedCircuitBreaker(10, 30*time.Second), // 10 failures, 30s timeout
|
||||
metricsCollector: NewMetricsCollector(),
|
||||
messageStore: NewInMemoryMessageStore(),
|
||||
shutdown: make(chan struct{}),
|
||||
logger: options.Logger(),
|
||||
connectionPool: NewConnectionPool(1000), // max 1000 connections
|
||||
healthChecker: NewHealthChecker(),
|
||||
circuitBreaker: NewEnhancedCircuitBreaker(10, 30*time.Second), // 10 failures, 30s timeout
|
||||
metricsCollector: NewMetricsCollector(),
|
||||
messageStore: NewInMemoryMessageStore(),
|
||||
authenticatedConns: memory.New[string, bool](),
|
||||
shutdown: make(chan struct{}),
|
||||
logger: options.Logger(),
|
||||
}
|
||||
|
||||
if options.enableSecurity {
|
||||
broker.securityManager = NewSecurityManager()
|
||||
}
|
||||
if options.enableMonitoring {
|
||||
if options.adminAddr != "" {
|
||||
broker.adminServer = NewAdminServer(broker, options.adminAddr, options.Logger())
|
||||
}
|
||||
if options.metricsAddr != "" {
|
||||
// Need to create MonitoringConfig, use default
|
||||
config := &MonitoringConfig{
|
||||
EnableMetrics: true,
|
||||
MetricsPort: 9090, // default
|
||||
MetricsPath: "/metrics",
|
||||
EnableHealthCheck: true,
|
||||
HealthCheckPort: 8080,
|
||||
HealthCheckPath: "/health",
|
||||
HealthCheckInterval: time.Minute,
|
||||
EnableLogging: true,
|
||||
LogLevel: "info",
|
||||
}
|
||||
broker.metricsServer = NewMetricsServer(broker, config, options.Logger())
|
||||
}
|
||||
}
|
||||
|
||||
broker.healthChecker.broker = broker
|
||||
@@ -508,6 +543,18 @@ func (b *Broker) Options() *Options {
|
||||
return b.opts
|
||||
}
|
||||
|
||||
func (b *Broker) SecurityManager() *SecurityManager {
|
||||
return b.securityManager
|
||||
}
|
||||
|
||||
// InitializeSecurity initializes default users, roles, and permissions for development/testing
|
||||
func (b *Broker) InitializeSecurity() error {
|
||||
if b.securityManager == nil {
|
||||
return fmt.Errorf("security manager not initialized")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
|
||||
consumerID, ok := GetConsumerID(ctx)
|
||||
if ok && consumerID != "" {
|
||||
@@ -553,6 +600,10 @@ func (b *Broker) OnClose(ctx context.Context, conn net.Conn) error {
|
||||
b.publishers.Del(publisherID)
|
||||
}
|
||||
}
|
||||
// Remove from authenticated connections
|
||||
connID := conn.RemoteAddr().String()
|
||||
b.authenticatedConns.Del(connID)
|
||||
|
||||
log.Printf("BROKER - Connection closed: address %s", conn.RemoteAddr())
|
||||
return nil
|
||||
}
|
||||
@@ -563,8 +614,29 @@ func (b *Broker) OnError(_ context.Context, conn net.Conn, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) isAuthenticated(connID string) bool {
|
||||
if b.securityManager == nil {
|
||||
return true // no security, allow all
|
||||
}
|
||||
_, ok := b.authenticatedConns.Get(connID)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Conn) {
|
||||
connID := conn.RemoteAddr().String()
|
||||
|
||||
// Check authentication for protected commands
|
||||
if b.securityManager != nil && (msg.Command == consts.PUBLISH || msg.Command == consts.SUBSCRIBE) {
|
||||
if !b.isAuthenticated(connID) {
|
||||
b.logger.Warn("Unauthenticated access attempt", logger.Field{Key: "command", Value: msg.Command.String()}, logger.Field{Key: "conn", Value: connID})
|
||||
// Send error response
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
switch msg.Command {
|
||||
case consts.AUTH:
|
||||
b.AuthHandler(ctx, conn, msg)
|
||||
case consts.PUBLISH:
|
||||
b.PublishHandler(ctx, conn, msg)
|
||||
case consts.SUBSCRIBE:
|
||||
@@ -588,6 +660,54 @@ func (b *Broker) OnMessage(ctx context.Context, msg *codec.Message, conn net.Con
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) AuthHandler(ctx context.Context, conn net.Conn, msg *codec.Message) {
|
||||
connID := conn.RemoteAddr().String()
|
||||
|
||||
// Parse auth credentials from payload
|
||||
var authReq map[string]any
|
||||
if err := json.Unmarshal(msg.Payload, &authReq); err != nil {
|
||||
b.logger.Error("Invalid auth request", logger.Field{Key: "error", Value: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
username, _ := authReq["username"].(string)
|
||||
password, _ := authReq["password"].(string)
|
||||
|
||||
// Authenticate
|
||||
user, err := b.securityManager.Authenticate(ctx, map[string]any{
|
||||
"username": username,
|
||||
"password": password,
|
||||
})
|
||||
if err != nil {
|
||||
b.logger.Warn("Authentication failed", logger.Field{Key: "username", Value: username}, logger.Field{Key: "conn", Value: connID})
|
||||
// Send AUTH_DENY
|
||||
denyMsg, err := codec.NewMessage(consts.AUTH_DENY, []byte(fmt.Sprintf(`{"error":"%s"}`, err.Error())), "", msg.Headers)
|
||||
if err != nil {
|
||||
b.logger.Error("Failed to create AUTH_DENY message", logger.Field{Key: "error", Value: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := b.send(ctx, conn, denyMsg); err != nil {
|
||||
b.logger.Error("Failed to send AUTH_DENY", logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Mark as authenticated
|
||||
b.authenticatedConns.Set(connID, true)
|
||||
|
||||
// Send AUTH_ACK
|
||||
ackMsg, err := codec.NewMessage(consts.AUTH_ACK, []byte(`{"status":"authenticated"}`), "", msg.Headers)
|
||||
if err != nil {
|
||||
b.logger.Error("Failed to create AUTH_ACK message", logger.Field{Key: "error", Value: err.Error()})
|
||||
return
|
||||
}
|
||||
if err := b.send(ctx, conn, ackMsg); err != nil {
|
||||
b.logger.Error("Failed to send AUTH_ACK", logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
|
||||
b.logger.Info("User authenticated", logger.Field{Key: "username", Value: user.Username}, logger.Field{Key: "conn", Value: connID})
|
||||
}
|
||||
|
||||
func (b *Broker) AdjustConsumerWorkers(noOfWorkers int, consumerID ...string) {
|
||||
b.consumers.ForEach(func(_ string, c *consumer) bool {
|
||||
return true
|
||||
@@ -1498,10 +1618,43 @@ func (b *Broker) StartEnhanced(ctx context.Context) error {
|
||||
b.wg.Add(1)
|
||||
go b.messageStoreCleanupRoutine()
|
||||
|
||||
// Start admin server if enabled
|
||||
if b.adminServer != nil {
|
||||
b.wg.Add(1)
|
||||
go func() {
|
||||
defer b.wg.Done()
|
||||
if err := b.adminServer.Start(); err != nil {
|
||||
b.logger.Error("Failed to start admin server", logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Start metrics server if enabled
|
||||
if b.metricsServer != nil {
|
||||
b.wg.Add(1)
|
||||
go func() {
|
||||
defer b.wg.Done()
|
||||
if err := b.metricsServer.Start(ctx); err != nil {
|
||||
b.logger.Error("Failed to start metrics server", logger.Field{Key: "error", Value: err.Error()})
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
b.logger.Info("Enhanced broker starting with production features enabled")
|
||||
|
||||
// Start the enhanced broker with its own implementation
|
||||
return b.startEnhancedBroker(ctx)
|
||||
if err := b.startEnhancedBroker(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for shutdown signal
|
||||
<-b.shutdown
|
||||
b.logger.Info("Enhanced broker shutting down")
|
||||
|
||||
// Wait for all goroutines to finish
|
||||
b.wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// startEnhancedBroker starts the core broker functionality
|
||||
|
36
options.go
36
options.go
@@ -314,3 +314,39 @@ func WithConsumerTimeout(timeout time.Duration) Option {
|
||||
opts.consumerTimeout = timeout
|
||||
}
|
||||
}
|
||||
|
||||
func WithAdminAddr(addr string) Option {
|
||||
return func(opts *Options) {
|
||||
opts.adminAddr = addr
|
||||
}
|
||||
}
|
||||
|
||||
func WithMetricsAddr(addr string) Option {
|
||||
return func(opts *Options) {
|
||||
opts.metricsAddr = addr
|
||||
}
|
||||
}
|
||||
|
||||
func WithSecurity(enabled bool) Option {
|
||||
return func(opts *Options) {
|
||||
opts.enableSecurity = enabled
|
||||
}
|
||||
}
|
||||
|
||||
func WithMonitoring(enabled bool) Option {
|
||||
return func(opts *Options) {
|
||||
opts.enableMonitoring = enabled
|
||||
}
|
||||
}
|
||||
|
||||
func WithUsername(username string) Option {
|
||||
return func(opts *Options) {
|
||||
opts.username = username
|
||||
}
|
||||
}
|
||||
|
||||
func WithPassword(password string) Option {
|
||||
return func(opts *Options) {
|
||||
opts.password = password
|
||||
}
|
||||
}
|
||||
|
67
publisher.go
67
publisher.go
@@ -18,10 +18,11 @@ import (
|
||||
)
|
||||
|
||||
type Publisher struct {
|
||||
opts *Options
|
||||
id string
|
||||
conn net.Conn
|
||||
connLock sync.Mutex
|
||||
opts *Options
|
||||
id string
|
||||
conn net.Conn
|
||||
connLock sync.Mutex
|
||||
authenticated bool
|
||||
}
|
||||
|
||||
func NewPublisher(id string, opts ...Option) *Publisher {
|
||||
@@ -60,12 +61,70 @@ func (p *Publisher) ensureConnection(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to connect to broker after retries: %w", err)
|
||||
}
|
||||
|
||||
// Auth authenticates the publisher with the broker
|
||||
func (p *Publisher) Auth(ctx context.Context, username, password string) error {
|
||||
if err := p.ensureConnection(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.connLock.Lock()
|
||||
conn := p.conn
|
||||
p.connLock.Unlock()
|
||||
|
||||
authPayload := map[string]string{
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
payload, err := json.Marshal(authPayload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
headers := map[string]string{
|
||||
consts.PublisherKey: p.id,
|
||||
consts.ContentType: consts.TypeJson,
|
||||
}
|
||||
msg, err := codec.NewMessage(consts.AUTH, payload, "", headers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = codec.SendMessage(ctx, conn, msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for AUTH_ACK
|
||||
resp, err := codec.ReadMessage(ctx, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.Command != consts.AUTH_ACK {
|
||||
return fmt.Errorf("authentication failed: %s", string(resp.Payload))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Publish method that uses the persistent connection.
|
||||
func (p *Publisher) Publish(ctx context.Context, task Task, queue string) error {
|
||||
// Ensure connection is established.
|
||||
if err := p.ensureConnection(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Authenticate if security is enabled and not already authenticated
|
||||
if p.opts.enableSecurity && !p.authenticated {
|
||||
if p.opts.username == "" || p.opts.password == "" {
|
||||
return fmt.Errorf("username and password required for authentication")
|
||||
}
|
||||
if err := p.Auth(ctx, p.opts.username, p.opts.password); err != nil {
|
||||
return fmt.Errorf("authentication failed: %w", err)
|
||||
}
|
||||
p.authenticated = true
|
||||
}
|
||||
|
||||
delay := p.opts.initialDelay
|
||||
for i := 0; i < p.opts.maxRetries; i++ {
|
||||
// Use the persistent connection.
|
||||
|
506
security.go
506
security.go
@@ -10,8 +10,318 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oarkflow/squealx"
|
||||
)
|
||||
|
||||
// Storage interfaces for persistence
|
||||
type UserStorage interface {
|
||||
GetUser(username string) (*User, error)
|
||||
GetPassword(username string) (string, error)
|
||||
SaveUser(user *User, password string) error
|
||||
ListUsers() ([]*User, error)
|
||||
DeleteUser(username string) error
|
||||
}
|
||||
|
||||
type RoleStorage interface {
|
||||
GetRole(name string) (*Role, error)
|
||||
SaveRole(role *Role) error
|
||||
ListRoles() ([]*Role, error)
|
||||
DeleteRole(name string) error
|
||||
}
|
||||
|
||||
type PermissionStorage interface {
|
||||
GetPermission(name string) (*Permission, error)
|
||||
SavePermission(perm *Permission) error
|
||||
ListPermissions() ([]*Permission, error)
|
||||
DeletePermission(name string) error
|
||||
}
|
||||
|
||||
// MemoryUserStorage implements UserStorage using in-memory storage
|
||||
type MemoryUserStorage struct {
|
||||
users map[string]*User
|
||||
passwords map[string]string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMemoryUserStorage() *MemoryUserStorage {
|
||||
return &MemoryUserStorage{
|
||||
users: make(map[string]*User),
|
||||
passwords: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (mus *MemoryUserStorage) GetUser(username string) (*User, error) {
|
||||
mus.mu.RLock()
|
||||
defer mus.mu.RUnlock()
|
||||
user, exists := mus.users[username]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (mus *MemoryUserStorage) SaveUser(user *User, password string) error {
|
||||
mus.mu.Lock()
|
||||
defer mus.mu.Unlock()
|
||||
mus.users[user.Username] = user
|
||||
mus.passwords[user.Username] = password
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mus *MemoryUserStorage) ListUsers() ([]*User, error) {
|
||||
mus.mu.RLock()
|
||||
defer mus.mu.RUnlock()
|
||||
users := make([]*User, 0, len(mus.users))
|
||||
for _, user := range mus.users {
|
||||
users = append(users, user)
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (mus *MemoryUserStorage) GetPassword(username string) (string, error) {
|
||||
mus.mu.RLock()
|
||||
defer mus.mu.RUnlock()
|
||||
password, exists := mus.passwords[username]
|
||||
if !exists {
|
||||
return "", fmt.Errorf("password not found")
|
||||
}
|
||||
return password, nil
|
||||
}
|
||||
|
||||
func (mus *MemoryUserStorage) DeleteUser(username string) error {
|
||||
mus.mu.Lock()
|
||||
defer mus.mu.Unlock()
|
||||
delete(mus.users, username)
|
||||
delete(mus.passwords, username)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MemoryRoleStorage implements RoleStorage using in-memory storage
|
||||
type MemoryRoleStorage struct {
|
||||
roles map[string]*Role
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMemoryRoleStorage() *MemoryRoleStorage {
|
||||
return &MemoryRoleStorage{
|
||||
roles: make(map[string]*Role),
|
||||
}
|
||||
}
|
||||
|
||||
func (mrs *MemoryRoleStorage) GetRole(name string) (*Role, error) {
|
||||
mrs.mu.RLock()
|
||||
defer mrs.mu.RUnlock()
|
||||
role, exists := mrs.roles[name]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("role not found")
|
||||
}
|
||||
return role, nil
|
||||
}
|
||||
|
||||
func (mrs *MemoryRoleStorage) SaveRole(role *Role) error {
|
||||
mrs.mu.Lock()
|
||||
defer mrs.mu.Unlock()
|
||||
mrs.roles[role.Name] = role
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mrs *MemoryRoleStorage) ListRoles() ([]*Role, error) {
|
||||
mrs.mu.RLock()
|
||||
defer mrs.mu.RUnlock()
|
||||
roles := make([]*Role, 0, len(mrs.roles))
|
||||
for _, role := range mrs.roles {
|
||||
roles = append(roles, role)
|
||||
}
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
func (mrs *MemoryRoleStorage) DeleteRole(name string) error {
|
||||
mrs.mu.Lock()
|
||||
defer mrs.mu.Unlock()
|
||||
delete(mrs.roles, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// MemoryPermissionStorage implements PermissionStorage using in-memory storage
|
||||
type MemoryPermissionStorage struct {
|
||||
permissions map[string]*Permission
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMemoryPermissionStorage() *MemoryPermissionStorage {
|
||||
return &MemoryPermissionStorage{
|
||||
permissions: make(map[string]*Permission),
|
||||
}
|
||||
}
|
||||
|
||||
func (mps *MemoryPermissionStorage) GetPermission(name string) (*Permission, error) {
|
||||
mps.mu.RLock()
|
||||
defer mps.mu.RUnlock()
|
||||
perm, exists := mps.permissions[name]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("permission not found")
|
||||
}
|
||||
return perm, nil
|
||||
}
|
||||
|
||||
func (mps *MemoryPermissionStorage) SavePermission(perm *Permission) error {
|
||||
mps.mu.Lock()
|
||||
defer mps.mu.Unlock()
|
||||
mps.permissions[perm.Name] = perm
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mps *MemoryPermissionStorage) ListPermissions() ([]*Permission, error) {
|
||||
mps.mu.RLock()
|
||||
defer mps.mu.RUnlock()
|
||||
perms := make([]*Permission, 0, len(mps.permissions))
|
||||
for _, perm := range mps.permissions {
|
||||
perms = append(perms, perm)
|
||||
}
|
||||
return perms, nil
|
||||
}
|
||||
|
||||
func (mps *MemoryPermissionStorage) DeletePermission(name string) error {
|
||||
mps.mu.Lock()
|
||||
defer mps.mu.Unlock()
|
||||
delete(mps.permissions, name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SQLUserStorage implements UserStorage using SQL database
|
||||
type SQLUserStorage struct {
|
||||
db *squealx.DB
|
||||
}
|
||||
|
||||
func NewSQLUserStorage(db *squealx.DB) *SQLUserStorage {
|
||||
return &SQLUserStorage{db: db}
|
||||
}
|
||||
|
||||
func (sus *SQLUserStorage) GetUser(username string) (*User, error) {
|
||||
var user User
|
||||
err := sus.db.Get(&user, "SELECT id, username, roles, permissions, metadata, created_at, last_login_at FROM users WHERE username = $1", username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (sus *SQLUserStorage) GetPassword(username string) (string, error) {
|
||||
var password string
|
||||
err := sus.db.Get(&password, "SELECT password FROM users WHERE username = $1", username)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return password, nil
|
||||
}
|
||||
|
||||
func (sus *SQLUserStorage) SaveUser(user *User, password string) error {
|
||||
_, err := sus.db.Exec(`
|
||||
INSERT INTO users (id, username, password, roles, permissions, metadata, created_at, last_login_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
ON CONFLICT (username) DO UPDATE SET
|
||||
password = EXCLUDED.password,
|
||||
roles = EXCLUDED.roles,
|
||||
permissions = EXCLUDED.permissions,
|
||||
metadata = EXCLUDED.metadata,
|
||||
last_login_at = EXCLUDED.last_login_at`,
|
||||
user.ID, user.Username, password, user.Roles, user.Permissions, user.Metadata, user.CreatedAt, user.LastLoginAt)
|
||||
return err
|
||||
}
|
||||
|
||||
func (sus *SQLUserStorage) ListUsers() ([]*User, error) {
|
||||
var users []*User
|
||||
err := sus.db.Select(&users, "SELECT id, username, roles, permissions, metadata, created_at, last_login_at FROM users")
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (sus *SQLUserStorage) DeleteUser(username string) error {
|
||||
_, err := sus.db.Exec("DELETE FROM users WHERE username = $1", username)
|
||||
return err
|
||||
}
|
||||
|
||||
// SQLRoleStorage implements RoleStorage using SQL database
|
||||
type SQLRoleStorage struct {
|
||||
db *squealx.DB
|
||||
}
|
||||
|
||||
func NewSQLRoleStorage(db *squealx.DB) *SQLRoleStorage {
|
||||
return &SQLRoleStorage{db: db}
|
||||
}
|
||||
|
||||
func (srs *SQLRoleStorage) GetRole(name string) (*Role, error) {
|
||||
var role Role
|
||||
err := srs.db.Get(&role, "SELECT name, description, permissions, created_at FROM roles WHERE name = $1", name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &role, nil
|
||||
}
|
||||
|
||||
func (srs *SQLRoleStorage) SaveRole(role *Role) error {
|
||||
_, err := srs.db.Exec(`
|
||||
INSERT INTO roles (name, description, permissions, created_at)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (name) DO UPDATE SET
|
||||
description = EXCLUDED.description,
|
||||
permissions = EXCLUDED.permissions`,
|
||||
role.Name, role.Description, role.Permissions, role.CreatedAt)
|
||||
return err
|
||||
}
|
||||
|
||||
func (srs *SQLRoleStorage) ListRoles() ([]*Role, error) {
|
||||
var roles []*Role
|
||||
err := srs.db.Select(&roles, "SELECT name, description, permissions, created_at FROM roles")
|
||||
return roles, err
|
||||
}
|
||||
|
||||
func (srs *SQLRoleStorage) DeleteRole(name string) error {
|
||||
_, err := srs.db.Exec("DELETE FROM roles WHERE name = $1", name)
|
||||
return err
|
||||
}
|
||||
|
||||
// SQLPermissionStorage implements PermissionStorage using SQL database
|
||||
type SQLPermissionStorage struct {
|
||||
db *squealx.DB
|
||||
}
|
||||
|
||||
func NewSQLPermissionStorage(db *squealx.DB) *SQLPermissionStorage {
|
||||
return &SQLPermissionStorage{db: db}
|
||||
}
|
||||
|
||||
func (sps *SQLPermissionStorage) GetPermission(name string) (*Permission, error) {
|
||||
var perm Permission
|
||||
err := sps.db.Get(&perm, "SELECT name, resource, action, description, created_at FROM permissions WHERE name = $1", name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &perm, nil
|
||||
}
|
||||
|
||||
func (sps *SQLPermissionStorage) SavePermission(perm *Permission) error {
|
||||
_, err := sps.db.Exec(`
|
||||
INSERT INTO permissions (name, resource, action, description, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (name) DO UPDATE SET
|
||||
resource = EXCLUDED.resource,
|
||||
action = EXCLUDED.action,
|
||||
description = EXCLUDED.description`,
|
||||
perm.Name, perm.Resource, perm.Action, perm.Description, perm.CreatedAt)
|
||||
return err
|
||||
}
|
||||
|
||||
func (sps *SQLPermissionStorage) ListPermissions() ([]*Permission, error) {
|
||||
var perms []*Permission
|
||||
err := sps.db.Select(&perms, "SELECT name, resource, action, description, created_at FROM permissions")
|
||||
return perms, err
|
||||
}
|
||||
|
||||
func (sps *SQLPermissionStorage) DeletePermission(name string) error {
|
||||
_, err := sps.db.Exec("DELETE FROM permissions WHERE name = $1", name)
|
||||
return err
|
||||
}
|
||||
|
||||
// SecurityManager handles authentication, authorization, and security policies
|
||||
type SecurityManager struct {
|
||||
authProviders map[string]AuthProvider
|
||||
@@ -34,6 +344,7 @@ type AuthProvider interface {
|
||||
type User struct {
|
||||
ID string `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"-"`
|
||||
Roles []string `json:"roles"`
|
||||
Permissions []string `json:"permissions"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
@@ -43,9 +354,9 @@ type User struct {
|
||||
|
||||
// RoleManager manages user roles and permissions
|
||||
type RoleManager struct {
|
||||
roles map[string]*Role
|
||||
permissions map[string]*Permission
|
||||
mu sync.RWMutex
|
||||
roleStorage RoleStorage
|
||||
permissionStorage PermissionStorage
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Role represents a user role with associated permissions
|
||||
@@ -123,93 +434,47 @@ func NewSecurityManager() *SecurityManager {
|
||||
key := make([]byte, 32)
|
||||
rand.Read(key)
|
||||
|
||||
return &SecurityManager{
|
||||
// Create memory storages by default
|
||||
userStorage := NewMemoryUserStorage()
|
||||
roleStorage := NewMemoryRoleStorage()
|
||||
permissionStorage := NewMemoryPermissionStorage()
|
||||
|
||||
sm := &SecurityManager{
|
||||
authProviders: make(map[string]AuthProvider),
|
||||
roleManager: NewRoleManager(),
|
||||
roleManager: NewRoleManager(roleStorage, permissionStorage),
|
||||
rateLimiter: NewSecurityRateLimiter(5, time.Minute*15), // 5 attempts per 15 minutes
|
||||
auditLogger: NewAuditLogger(10000),
|
||||
sessionManager: NewSessionManager(time.Hour * 24), // 24 hour sessions
|
||||
encryptionKey: key,
|
||||
}
|
||||
|
||||
// Add default basic auth provider
|
||||
basicProvider := NewBasicAuthProvider(userStorage)
|
||||
sm.AddAuthProvider(basicProvider)
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// NewRoleManager creates a new role manager
|
||||
func NewRoleManager() *RoleManager {
|
||||
rm := &RoleManager{
|
||||
roles: make(map[string]*Role),
|
||||
permissions: make(map[string]*Permission),
|
||||
func NewRoleManager(roleStorage RoleStorage, permissionStorage PermissionStorage) *RoleManager {
|
||||
return &RoleManager{
|
||||
roleStorage: roleStorage,
|
||||
permissionStorage: permissionStorage,
|
||||
}
|
||||
|
||||
// Initialize default permissions
|
||||
rm.AddPermission(&Permission{
|
||||
Name: "task.publish",
|
||||
Resource: "task",
|
||||
Action: "publish",
|
||||
Description: "Publish tasks to queues",
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
rm.AddPermission(&Permission{
|
||||
Name: "task.consume",
|
||||
Resource: "task",
|
||||
Action: "consume",
|
||||
Description: "Consume tasks from queues",
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
rm.AddPermission(&Permission{
|
||||
Name: "queue.manage",
|
||||
Resource: "queue",
|
||||
Action: "manage",
|
||||
Description: "Manage queues",
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
rm.AddPermission(&Permission{
|
||||
Name: "admin.system",
|
||||
Resource: "system",
|
||||
Action: "admin",
|
||||
Description: "System administration",
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
// Initialize default roles
|
||||
rm.AddRole(&Role{
|
||||
Name: "publisher",
|
||||
Description: "Can publish tasks",
|
||||
Permissions: []string{"task.publish"},
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
rm.AddRole(&Role{
|
||||
Name: "consumer",
|
||||
Description: "Can consume tasks",
|
||||
Permissions: []string{"task.consume"},
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
rm.AddRole(&Role{
|
||||
Name: "admin",
|
||||
Description: "Full system access",
|
||||
Permissions: []string{"task.publish", "task.consume", "queue.manage", "admin.system"},
|
||||
CreatedAt: time.Now(),
|
||||
})
|
||||
|
||||
return rm
|
||||
}
|
||||
|
||||
// AddPermission adds a permission to the role manager
|
||||
func (rm *RoleManager) AddPermission(perm *Permission) {
|
||||
func (rm *RoleManager) AddPermission(perm *Permission) error {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
rm.permissions[perm.Name] = perm
|
||||
return rm.permissionStorage.SavePermission(perm)
|
||||
}
|
||||
|
||||
// AddRole adds a role to the role manager
|
||||
func (rm *RoleManager) AddRole(role *Role) {
|
||||
func (rm *RoleManager) AddRole(role *Role) error {
|
||||
rm.mu.Lock()
|
||||
defer rm.mu.Unlock()
|
||||
rm.roles[role.Name] = role
|
||||
return rm.roleStorage.SaveRole(role)
|
||||
}
|
||||
|
||||
// HasPermission checks if a user has a specific permission
|
||||
@@ -218,11 +483,13 @@ func (rm *RoleManager) HasPermission(user *User, permission string) bool {
|
||||
defer rm.mu.RUnlock()
|
||||
|
||||
for _, roleName := range user.Roles {
|
||||
if role, exists := rm.roles[roleName]; exists {
|
||||
for _, perm := range role.Permissions {
|
||||
if perm == permission {
|
||||
return true
|
||||
}
|
||||
role, err := rm.roleStorage.GetRole(roleName)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, perm := range role.Permissions {
|
||||
if perm == permission {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -236,10 +503,12 @@ func (rm *RoleManager) GetUserPermissions(user *User) []string {
|
||||
|
||||
permissions := make(map[string]bool)
|
||||
for _, roleName := range user.Roles {
|
||||
if role, exists := rm.roles[roleName]; exists {
|
||||
for _, perm := range role.Permissions {
|
||||
permissions[perm] = true
|
||||
}
|
||||
role, err := rm.roleStorage.GetRole(roleName)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, perm := range role.Permissions {
|
||||
permissions[perm] = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -469,6 +738,60 @@ func (sm *SecurityManager) Authenticate(ctx context.Context, credentials map[str
|
||||
return nil, fmt.Errorf("authentication failed: %w", lastErr)
|
||||
}
|
||||
|
||||
func (sm *SecurityManager) AddPermission(perm *Permission) error {
|
||||
if perm == nil || (perm.Name == "" && (perm.Resource == "" || perm.Action == "")) {
|
||||
return fmt.Errorf("invalid permission")
|
||||
}
|
||||
return sm.roleManager.AddPermission(perm)
|
||||
}
|
||||
|
||||
func (sm *SecurityManager) AddRole(role *Role) error {
|
||||
if role == nil || role.Name == "" {
|
||||
return fmt.Errorf("invalid role")
|
||||
}
|
||||
return sm.roleManager.AddRole(role)
|
||||
}
|
||||
|
||||
func (sm *SecurityManager) AddUsers(users ...*User) error {
|
||||
for _, user := range users {
|
||||
if user == nil || user.Username == "" || user.Password == "" {
|
||||
return fmt.Errorf("invalid user")
|
||||
}
|
||||
if err := sm.AddUser(user); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *SecurityManager) AddRoles(roles ...*Role) error {
|
||||
for _, role := range roles {
|
||||
if err := sm.AddRole(role); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *SecurityManager) AddPermissions(perms ...*Permission) error {
|
||||
for _, perm := range perms {
|
||||
if err := sm.AddPermission(perm); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddUser adds a user to the system
|
||||
func (sm *SecurityManager) AddUser(user *User) error {
|
||||
for _, provider := range sm.authProviders {
|
||||
if bap, ok := provider.(*BasicAuthProvider); ok {
|
||||
return bap.AddUser(user, user.Password)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("no suitable auth provider found")
|
||||
}
|
||||
|
||||
// Authorize checks if a user is authorized for an action
|
||||
func (sm *SecurityManager) Authorize(user *User, resource, action string) error {
|
||||
permission := fmt.Sprintf("%s.%s", resource, action)
|
||||
@@ -551,13 +874,12 @@ func (sm *SecurityManager) Decrypt(data []byte) ([]byte, error) {
|
||||
|
||||
// BasicAuthProvider implements basic username/password authentication
|
||||
type BasicAuthProvider struct {
|
||||
users map[string]*User
|
||||
mu sync.RWMutex
|
||||
userStorage UserStorage
|
||||
}
|
||||
|
||||
func NewBasicAuthProvider() *BasicAuthProvider {
|
||||
func NewBasicAuthProvider(userStorage UserStorage) *BasicAuthProvider {
|
||||
return &BasicAuthProvider{
|
||||
users: make(map[string]*User),
|
||||
userStorage: userStorage,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -576,16 +898,13 @@ func (bap *BasicAuthProvider) Authenticate(ctx context.Context, credentials map[
|
||||
return nil, fmt.Errorf("password required")
|
||||
}
|
||||
|
||||
bap.mu.RLock()
|
||||
user, exists := bap.users[username]
|
||||
bap.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
user, err := bap.userStorage.GetUser(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
// In production, compare hashed passwords
|
||||
if password != "password" { // Placeholder
|
||||
storedPassword, err := bap.userStorage.GetPassword(username)
|
||||
if err != nil || storedPassword != password {
|
||||
return nil, fmt.Errorf("invalid password")
|
||||
}
|
||||
|
||||
@@ -611,14 +930,9 @@ func (bap *BasicAuthProvider) ValidateToken(token string) (*User, error) {
|
||||
}
|
||||
|
||||
func (bap *BasicAuthProvider) AddUser(user *User, password string) error {
|
||||
bap.mu.Lock()
|
||||
defer bap.mu.Unlock()
|
||||
|
||||
// In production, hash the password
|
||||
user.CreatedAt = time.Now()
|
||||
bap.users[user.Username] = user
|
||||
|
||||
return nil
|
||||
return bap.userStorage.SaveUser(user, password)
|
||||
}
|
||||
|
||||
// generateID generates a random ID
|
||||
|
Reference in New Issue
Block a user