diff --git a/README.md b/README.md index 907aa52..95b8cab 100644 --- a/README.md +++ b/README.md @@ -465,22 +465,24 @@ accessToken, _ := oauth2Server.ExchangeCodeForToken( Listen to authentication and authorization events for audit logging, security monitoring, etc: ```go -// Create event manager -eventMgr := core.NewEventManager() +storage := memory.NewStorage() + +manager := core.NewBuilder(). + Storage(storage). + Build() // Listen to login events -eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { +manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { fmt.Printf("[LOGIN] User: %s, Token: %s\n", data.LoginID, data.Token) - // Log audit, send notifications, etc. }) // Listen to logout events -eventMgr.RegisterFunc(core.EventLogout, func(data *core.EventData) { +manager.RegisterFunc(core.EventLogout, func(data *core.EventData) { fmt.Printf("[LOGOUT] User: %s\n", data.LoginID) }) // Advanced: priority and sync execution -eventMgr.RegisterWithConfig(core.EventLogin, +manager.RegisterWithConfig(core.EventLogin, core.ListenerFunc(auditLogger), core.ListenerConfig{ Priority: 100, // High priority @@ -489,9 +491,15 @@ eventMgr.RegisterWithConfig(core.EventLogin, ) // Listen to all events (wildcard) -eventMgr.RegisterFunc(core.EventAll, func(data *core.EventData) { +manager.RegisterFunc(core.EventAll, func(data *core.EventData) { log.Printf("[%s] %s", data.Event, data.LoginID) }) + +// Access advanced controls via the underlying EventManager +manager.GetEventManager().SetPanicHandler(customPanicHandler) + +// Use the manager globally +stputil.SetManager(manager) ``` **Available events:** diff --git a/README_zh.md b/README_zh.md index 0c09df4..cfe5ae3 100644 --- a/README_zh.md +++ b/README_zh.md @@ -466,22 +466,25 @@ accessToken, _ := oauth2Server.ExchangeCodeForToken( 监听认证和授权事件,实现审计日志、安全监控等功能: ```go -// 创建事件管理器 -eventMgr := core.NewEventManager() +storage := memory.NewStorage() + +manager := core.NewBuilder(). + Storage(storage). + Build() // 监听登录事件 -eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { +manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { fmt.Printf("[LOGIN] User: %s, Token: %s\n", data.LoginID, data.Token) // 记录审计日志、发送通知等 }) -// 监听注销事件 -eventMgr.RegisterFunc(core.EventLogout, func(data *core.EventData) { +// 监听登出事件 +manager.RegisterFunc(core.EventLogout, func(data *core.EventData) { fmt.Printf("[LOGOUT] User: %s\n", data.LoginID) }) // 高级特性:优先级、同步执行 -eventMgr.RegisterWithConfig(core.EventLogin, +manager.RegisterWithConfig(core.EventLogin, core.ListenerFunc(auditLogger), core.ListenerConfig{ Priority: 100, // 高优先级 @@ -490,9 +493,15 @@ eventMgr.RegisterWithConfig(core.EventLogin, ) // 监听所有事件(通配符) -eventMgr.RegisterFunc(core.EventAll, func(data *core.EventData) { +manager.RegisterFunc(core.EventAll, func(data *core.EventData) { log.Printf("[%s] %s", data.Event, data.LoginID) }) + +// 可通过底层 EventManager 访问更多控制能力 +manager.GetEventManager().SetPanicHandler(customPanicHandler) + +// 设置全局管理器 +stputil.SetManager(manager) ``` **可用事件:** diff --git a/core/manager/manager.go b/core/manager/manager.go index 2fa4deb..a257217 100644 --- a/core/manager/manager.go +++ b/core/manager/manager.go @@ -7,6 +7,7 @@ import ( "github.com/click33/sa-token-go/core/adapter" "github.com/click33/sa-token-go/core/config" + "github.com/click33/sa-token-go/core/listener" "github.com/click33/sa-token-go/core/oauth2" "github.com/click33/sa-token-go/core/security" "github.com/click33/sa-token-go/core/session" @@ -63,6 +64,7 @@ type Manager struct { nonceManager *security.NonceManager refreshManager *security.RefreshTokenManager oauth2Server *oauth2.OAuth2Server + eventManager *listener.Manager } // NewManager Creates a new manager | 创建管理器 @@ -85,6 +87,7 @@ func NewManager(storage adapter.Storage, cfg *config.Config) *Manager { nonceManager: security.NewNonceManager(storage, prefix, DefaultNonceTTL), refreshManager: security.NewRefreshTokenManager(storage, prefix, cfg), oauth2Server: oauth2.NewOAuth2Server(storage, prefix), + eventManager: listener.NewManager(), } } @@ -154,6 +157,16 @@ func (m *Manager) Login(loginID string, device ...string) (string, error) { sess.Set(SessionKeyDevice, deviceType) sess.Set(SessionKeyLoginTime, time.Now().Unix()) + // Trigger login event | 触发登录事件 + if m.eventManager != nil { + m.eventManager.Trigger(&listener.EventData{ + Event: listener.EventLogin, + LoginID: loginID, + Token: tokenValue, + Device: deviceType, + }) + } + return tokenValue, nil } @@ -194,6 +207,15 @@ func (m *Manager) Logout(loginID string, device ...string) error { // Delete account mapping | 删除账号映射 m.storage.Delete(accountKey) + // Trigger logout event | 触发登出事件 + if m.eventManager != nil { + m.eventManager.Trigger(&listener.EventData{ + Event: listener.EventLogout, + LoginID: loginID, + Device: deviceType, + }) + } + return nil } @@ -202,8 +224,23 @@ func (m *Manager) LogoutByToken(tokenValue string) error { if tokenValue == "" { return nil } + + // Get loginID before deletion for event | 删除前获取loginID用于事件 + loginID, _ := m.getLoginIDByToken(tokenValue) + tokenKey := m.getTokenKey(tokenValue) - return m.storage.Delete(tokenKey) + err := m.storage.Delete(tokenKey) + + // Trigger logout event | 触发登出事件 + if m.eventManager != nil && loginID != "" { + m.eventManager.Trigger(&listener.EventData{ + Event: listener.EventLogout, + LoginID: loginID, + Token: tokenValue, + }) + } + + return err } // kickout Kick user offline (private) | 踢人下线(私有) @@ -220,6 +257,17 @@ func (m *Manager) kickout(loginID string, device string) error { } tokenKey := m.getTokenKey(tokenStr) + + // Trigger kickout event | 触发踢出事件 + if m.eventManager != nil { + m.eventManager.Trigger(&listener.EventData{ + Event: listener.EventKickout, + LoginID: loginID, + Token: tokenStr, + Device: device, + }) + } + return m.storage.Delete(tokenKey) } @@ -638,6 +686,58 @@ func (m *Manager) toStringSlice(v any) []string { } } +// ============ Event Management | 事件管理 ============ + +// RegisterFunc registers a function as an event listener | 注册函数作为事件监听器 +func (m *Manager) RegisterFunc(event listener.Event, fn func(*listener.EventData)) { + if m.eventManager != nil { + m.eventManager.RegisterFunc(event, fn) + } +} + +// Register registers an event listener | 注册事件监听器 +func (m *Manager) Register(event listener.Event, listener listener.Listener) string { + if m.eventManager != nil { + return m.eventManager.Register(event, listener) + } + return "" +} + +// RegisterWithConfig registers an event listener with config | 注册带配置的事件监听器 +func (m *Manager) RegisterWithConfig(event listener.Event, listener listener.Listener, config listener.ListenerConfig) string { + if m.eventManager != nil { + return m.eventManager.RegisterWithConfig(event, listener, config) + } + return "" +} + +// Unregister removes an event listener by ID | 根据ID移除事件监听器 +func (m *Manager) Unregister(id string) bool { + if m.eventManager != nil { + return m.eventManager.Unregister(id) + } + return false +} + +// TriggerEvent manually triggers an event | 手动触发事件 +func (m *Manager) TriggerEvent(data *listener.EventData) { + if m.eventManager != nil { + m.eventManager.Trigger(data) + } +} + +// WaitEvents waits for all async event listeners to complete | 等待所有异步事件监听器完成 +func (m *Manager) WaitEvents() { + if m.eventManager != nil { + m.eventManager.Wait() + } +} + +// GetEventManager gets the event manager | 获取事件管理器 +func (m *Manager) GetEventManager() *listener.Manager { + return m.eventManager +} + // ============ Public Getters | 公共获取器 ============ // GetConfig Gets configuration | 获取配置 diff --git a/core/security/refresh_token.go b/core/security/refresh_token.go index 095b15a..f6d21a0 100644 --- a/core/security/refresh_token.go +++ b/core/security/refresh_token.go @@ -3,6 +3,7 @@ package security import ( "crypto/rand" "encoding/hex" + "encoding/json" "fmt" "time" @@ -42,12 +43,22 @@ var ( // RefreshTokenInfo refresh token information | 刷新令牌信息 type RefreshTokenInfo struct { - RefreshToken string // Refresh token (long-lived) | 刷新令牌(长期有效) - AccessToken string // Access token (short-lived) | 访问令牌(短期有效) - LoginID string // User login ID | 用户登录ID - Device string // Device type | 设备类型 - CreateTime int64 // Creation timestamp | 创建时间戳 - ExpireTime int64 // Expiration timestamp | 过期时间戳 + RefreshToken string `json:"refreshToken"` // Refresh token (long-lived) | 刷新令牌(长期有效) + AccessToken string `json:"accessToken"` // Access token (short-lived) | 访问令牌(短期有效) + LoginID string `json:"loginID"` // User login ID | 用户登录ID + Device string `json:"device"` // Device type | 设备类型 + CreateTime int64 `json:"createTime"` // Creation timestamp | 创建时间戳 + ExpireTime int64 `json:"expireTime"` // Expiration timestamp | 过期时间戳 +} + +// MarshalBinary implements encoding.BinaryMarshaler for Redis storage | 实现encoding.BinaryMarshaler接口用于Redis存储 +func (r *RefreshTokenInfo) MarshalBinary() ([]byte, error) { + return json.Marshal(r) +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler for Redis storage | 实现encoding.BinaryUnmarshaler接口用于Redis存储 +func (r *RefreshTokenInfo) UnmarshalBinary(data []byte) error { + return json.Unmarshal(data, r) } // RefreshTokenManager Refresh token manager | 刷新令牌管理器 diff --git a/docs/guide/listener.md b/docs/guide/listener.md index 2a62623..dc2b65d 100644 --- a/docs/guide/listener.md +++ b/docs/guide/listener.md @@ -22,24 +22,32 @@ Sa-Token-Go provides a powerful event system for monitoring authentication and a ## Basic Usage -### Create Event Manager +### Create Manager with Event Support ```go -import "github.com/click33/sa-token-go/core" +import ( + "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/storage/memory" +) -eventMgr := core.NewEventManager() +manager := core.NewBuilder(). + Storage(memory.NewStorage()). + Build() + +// Optional: direct access to advanced controls +eventMgr := manager.GetEventManager() ``` ### Register Listener (Function) ```go // Listen to login event -eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { +manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { fmt.Printf("[LOGIN] User: %s, Token: %s\n", data.LoginID, data.Token) }) // Listen to logout event -eventMgr.RegisterFunc(core.EventLogout, func(data *core.EventData) { +manager.RegisterFunc(core.EventLogout, func(data *core.EventData) { fmt.Printf("[LOGOUT] User: %s\n", data.LoginID) }) ``` @@ -54,7 +62,7 @@ func (a *AuditLogger) OnEvent(data *core.EventData) { fmt.Printf("[AUDIT] Event: %s, User: %s\n", data.Event, data.LoginID) } -eventMgr.Register(core.EventLogin, &AuditLogger{}) +manager.Register(core.EventLogin, &AuditLogger{}) ``` ## Advanced Features @@ -63,7 +71,7 @@ eventMgr.Register(core.EventLogin, &AuditLogger{}) ```go // Higher priority listeners execute first -eventMgr.RegisterWithConfig( +manager.RegisterWithConfig( core.EventLogin, myListener, core.ListenerConfig{ @@ -76,7 +84,7 @@ eventMgr.RegisterWithConfig( ```go // Execute synchronously (blocking) -eventMgr.RegisterWithConfig( +manager.RegisterWithConfig( core.EventLogin, myListener, core.ListenerConfig{ @@ -89,7 +97,7 @@ eventMgr.RegisterWithConfig( ```go // Listen to all events -eventMgr.RegisterFunc(core.EventAll, func(data *core.EventData) { +manager.RegisterFunc(core.EventAll, func(data *core.EventData) { fmt.Printf("[ALL] Event: %s, User: %s\n", data.Event, data.LoginID) }) ``` @@ -98,7 +106,7 @@ eventMgr.RegisterFunc(core.EventAll, func(data *core.EventData) { ```go // Register and get ID -id := eventMgr.RegisterWithConfig( +id := manager.RegisterWithConfig( core.EventLogin, myListener, core.ListenerConfig{ @@ -107,7 +115,7 @@ id := eventMgr.RegisterWithConfig( ) // Unregister by ID -eventMgr.Unregister(id) +manager.Unregister(id) ``` ## Use Cases @@ -115,7 +123,7 @@ eventMgr.Unregister(id) ### Audit Logging ```go -eventMgr.RegisterFunc(core.EventAll, func(data *core.EventData) { +manager.RegisterFunc(core.EventAll, func(data *core.EventData) { log.Printf("[AUDIT] %s - User: %s, IP: %s, Time: %d", data.Event, data.LoginID, data.Extra["ip"], data.Timestamp) }) @@ -124,7 +132,7 @@ eventMgr.RegisterFunc(core.EventAll, func(data *core.EventData) { ### Security Monitoring ```go -eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { +manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { // Check for suspicious login // Send alert if needed }) @@ -133,7 +141,7 @@ eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { ### Session Analytics ```go -eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { +manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { // Track active users // Update analytics }) diff --git a/docs/guide/listener_zh.md b/docs/guide/listener_zh.md index addfac9..75c716b 100644 --- a/docs/guide/listener_zh.md +++ b/docs/guide/listener_zh.md @@ -39,35 +39,32 @@ The event system allows you to: ## Basic Usage -### 1. Create an Event Manager +### 1. 创建带事件功能的 Manager ```go -import "github.com/click33/sa-token-go/core" +import ( + "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/storage/memory" +) -eventManager := core.NewEventManager() -``` - -### 2. Register a Simple Listener - -```go -// Function-based listener -eventManager.RegisterFunc(core.EventLogin, func(data *core.EventData) { - fmt.Printf("User logged in: %s\n", data.LoginID) -}) -``` - -### 3. Register with the Manager - -```go manager := core.NewBuilder(). Storage(memory.NewStorage()). Build() -// Set event manager (if your Manager supports it) -// Note: You may need to integrate this into your Manager initialization +// 如需高级控制,可以获取底层事件管理器 +eventMgr := manager.GetEventManager() ``` -### 4. Complete Example +### 2. 注册简单监听器 + +```go +// 基于函数的监听器 +manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { + fmt.Printf("User logged in: %s\n", data.LoginID) +}) +``` + +### 3. 完整示例 ```go package main @@ -80,26 +77,24 @@ import ( ) func main() { - // Create event manager - eventMgr := core.NewEventManager() + // Create manager with default event support + manager := core.NewBuilder(). + Storage(memory.NewStorage()). + Build() // Register login listener - eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { + manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { fmt.Printf("[LOGIN] User: %s, Token: %s, Device: %s\n", data.LoginID, data.Token, data.Device) }) // Register logout listener - eventMgr.RegisterFunc(core.EventLogout, func(data *core.EventData) { + manager.RegisterFunc(core.EventLogout, func(data *core.EventData) { fmt.Printf("[LOGOUT] User: %s\n", data.LoginID) }) // Initialize StpUtil - stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). - Build(), - ) + stputil.SetManager(manager) // Perform login (will trigger event) token, _ := stputil.Login(1000) @@ -118,7 +113,7 @@ Control the execution order of listeners using priorities: ```go // High priority listener (executes first) -eventMgr.RegisterWithConfig(core.EventLogin, +manager.RegisterWithConfig(core.EventLogin, core.ListenerFunc(func(data *core.EventData) { fmt.Println("High priority listener") }), @@ -129,7 +124,7 @@ eventMgr.RegisterWithConfig(core.EventLogin, ) // Low priority listener (executes later) -eventMgr.RegisterWithConfig(core.EventLogin, +manager.RegisterWithConfig(core.EventLogin, core.ListenerFunc(func(data *core.EventData) { fmt.Println("Low priority listener") }), @@ -144,7 +139,7 @@ eventMgr.RegisterWithConfig(core.EventLogin, ```go // Synchronous listener (blocks until complete) -eventMgr.RegisterWithConfig(core.EventLogin, +manager.RegisterWithConfig(core.EventLogin, core.ListenerFunc(func(data *core.EventData) { // Critical operation that must complete before continuing saveToDatabase(data) @@ -155,7 +150,7 @@ eventMgr.RegisterWithConfig(core.EventLogin, ) // Asynchronous listener (non-blocking) -eventMgr.RegisterWithConfig(core.EventLogin, +manager.RegisterWithConfig(core.EventLogin, core.ListenerFunc(func(data *core.EventData) { // Non-critical operation (logging, analytics) sendToAnalytics(data) @@ -170,7 +165,7 @@ eventMgr.RegisterWithConfig(core.EventLogin, ```go // Register with a custom ID -listenerID := eventMgr.RegisterWithConfig(core.EventLogin, +listenerID := manager.RegisterWithConfig(core.EventLogin, core.ListenerFunc(func(data *core.EventData) { fmt.Println("Temporary listener") }), @@ -180,7 +175,7 @@ listenerID := eventMgr.RegisterWithConfig(core.EventLogin, ) // Later, unregister by ID -eventMgr.Unregister(listenerID) +manager.Unregister(listenerID) ``` ### Wildcard Listeners @@ -189,7 +184,7 @@ Listen to all events: ```go // Listen to all events -eventMgr.RegisterFunc(core.EventAll, func(data *core.EventData) { +manager.RegisterFunc(core.EventAll, func(data *core.EventData) { fmt.Printf("[%s] LoginID: %s\n", data.Event, data.LoginID) }) ``` @@ -230,7 +225,7 @@ stputil.Login(1000) stputil.Login(2000) // Wait for all async listeners to complete -eventMgr.Wait() +manager.WaitEvents() ``` ## Best Practices @@ -239,7 +234,7 @@ eventMgr.Wait() ```go // ✅ Good: Async for logging -eventMgr.RegisterWithConfig(core.EventLogin, +manager.RegisterWithConfig(core.EventLogin, core.ListenerFunc(func(data *core.EventData) { logToFile(data) // Can be async }), @@ -247,7 +242,7 @@ eventMgr.RegisterWithConfig(core.EventLogin, ) // ❌ Avoid: Sync for slow operations -eventMgr.RegisterWithConfig(core.EventLogin, +manager.RegisterWithConfig(core.EventLogin, core.ListenerFunc(func(data *core.EventData) { sendEmail(data) // Slow operation blocks login }), @@ -259,12 +254,12 @@ eventMgr.RegisterWithConfig(core.EventLogin, ```go // ✅ Good: Quick processing -eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { +manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { counter.Increment("login_count") }) // ❌ Avoid: Heavy processing -eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { +manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { processLargeDataset() // This should be in a background job }) ``` @@ -273,13 +268,13 @@ eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { ```go // Validation (high priority) -eventMgr.RegisterWithConfig(core.EventLogin, +manager.RegisterWithConfig(core.EventLogin, validationListener, core.ListenerConfig{Priority: 100}, ) // Logging (low priority) -eventMgr.RegisterWithConfig(core.EventLogin, +manager.RegisterWithConfig(core.EventLogin, loggingListener, core.ListenerConfig{Priority: 10}, ) @@ -288,7 +283,7 @@ eventMgr.RegisterWithConfig(core.EventLogin, ### 4. Handle Errors Gracefully ```go -eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { +manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { defer func() { if r := recover(); r != nil { log.Printf("Listener error: %v", r) @@ -317,7 +312,7 @@ func (a *AuditLogger) OnEvent(data *core.EventData) { // Usage logger := &AuditLogger{file: logFile} -eventMgr.Register(core.EventAll, logger) +manager.Register(core.EventAll, logger) ``` ### Example 2: Security Monitor @@ -338,14 +333,14 @@ func (s *SecurityMonitor) OnEvent(data *core.EventData) { // Usage monitor := &SecurityMonitor{alertChan: make(chan string, 100)} -eventMgr.Register(core.EventKickout, monitor) -eventMgr.Register(core.EventDisable, monitor) +manager.Register(core.EventKickout, monitor) +manager.Register(core.EventDisable, monitor) ``` ### Example 3: Login Counter with Redis ```go -eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { +manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { // Increment daily login counter key := fmt.Sprintf("login:count:%s", time.Now().Format("2006-01-02")) redisClient.Incr(ctx, key) @@ -360,7 +355,7 @@ eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { ### Example 4: Multi-Factor Authentication ```go -eventMgr.RegisterWithConfig(core.EventLogin, +manager.RegisterWithConfig(core.EventLogin, core.ListenerFunc(func(data *core.EventData) { // Check if MFA is required if requiresMFA(data.LoginID) { @@ -404,8 +399,8 @@ func (s *SessionAnalytics) OnEvent(data *core.EventData) { // Usage analytics := &SessionAnalytics{sessions: make(map[string]time.Time)} -eventMgr.Register(core.EventCreateSession, analytics) -eventMgr.Register(core.EventDestroySession, analytics) +manager.Register(core.EventCreateSession, analytics) +manager.Register(core.EventDestroySession, analytics) ``` ## EventData Structure @@ -424,7 +419,7 @@ type EventData struct { ### Accessing Extra Data ```go -eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { +manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { if ipAddr, ok := data.Extra["ip_address"].(string); ok { fmt.Printf("Login from IP: %s\n", ipAddr) } @@ -441,8 +436,8 @@ All event manager operations are thread-safe: ```go // Safe to call from multiple goroutines -go eventMgr.RegisterFunc(core.EventLogin, handler1) -go eventMgr.RegisterFunc(core.EventLogin, handler2) +go manager.RegisterFunc(core.EventLogin, handler1) +go manager.RegisterFunc(core.EventLogin, handler2) go eventMgr.Trigger(&core.EventData{Event: core.EventLogin}) ``` diff --git a/docs/guide/refresh-token.md b/docs/guide/refresh-token.md index 8635bc9..294eb8e 100644 --- a/docs/guide/refresh-token.md +++ b/docs/guide/refresh-token.md @@ -289,7 +289,7 @@ if REFRESH_TOKEN_ROTATION { ```go // Log refresh events -eventMgr.RegisterFunc(core.EventRefresh, func(data *core.EventData) { +manager.RegisterFunc(core.EventRefresh, func(data *core.EventData) { // Detect abnormal refresh patterns if isAbnormalRefreshPattern(data.LoginID) { alert("Possible token leak") diff --git a/docs/guide/refresh-token_zh.md b/docs/guide/refresh-token_zh.md index 216b76e..c0f1fbc 100644 --- a/docs/guide/refresh-token_zh.md +++ b/docs/guide/refresh-token_zh.md @@ -289,7 +289,7 @@ if REFRESH_TOKEN_ROTATION { ```go // 记录刷新事件 -eventMgr.RegisterFunc(core.EventRefresh, func(data *core.EventData) { +manager.RegisterFunc(core.EventRefresh, func(data *core.EventData) { // 检测异常刷新模式 if isAbnormalRefreshPattern(data.LoginID) { alert("可能的令牌泄露") diff --git a/examples/listener-example/README.md b/examples/listener-example/README.md index fd815f0..aa14290 100644 --- a/examples/listener-example/README.md +++ b/examples/listener-example/README.md @@ -59,12 +59,22 @@ Remaining listeners: 4 ## Key Concepts +In this example the authentication manager automatically owns an internal event manager: + +```go +manager := core.NewBuilder(). + Storage(memory.NewStorage()). + Build() + +eventMgr := manager.GetEventManager() // Advanced controls (stats, enable/disable, panic handler, ...) +``` + ### Function Listeners The simplest way to register an event handler: ```go -eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { +manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { fmt.Printf("User %s logged in\n", data.LoginID) }) ``` @@ -74,7 +84,7 @@ eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { Control execution order with priorities: ```go -eventMgr.RegisterWithConfig(core.EventLogin, +manager.RegisterWithConfig(core.EventLogin, myListener, core.ListenerConfig{ Priority: 100, // Higher = executes first @@ -88,7 +98,7 @@ eventMgr.RegisterWithConfig(core.EventLogin, Listen to all events: ```go -eventMgr.RegisterFunc(core.EventAll, func(data *core.EventData) { +manager.RegisterFunc(core.EventAll, func(data *core.EventData) { // This will be called for every event }) ``` @@ -99,12 +109,12 @@ Add and remove listeners at runtime: ```go // Register with custom ID -id := eventMgr.RegisterWithConfig(event, listener, core.ListenerConfig{ +id := manager.RegisterWithConfig(event, listener, core.ListenerConfig{ ID: "my-listener", }) // Later, unregister -eventMgr.Unregister(id) +manager.Unregister(id) ``` ## Use Cases @@ -112,7 +122,7 @@ eventMgr.Unregister(id) ### 1. Audit Logging ```go -eventMgr.RegisterFunc(core.EventAll, func(data *core.EventData) { +manager.RegisterFunc(core.EventAll, func(data *core.EventData) { auditLog.Write(fmt.Sprintf("[%s] %s - %s", data.Event, data.LoginID, time.Unix(data.Timestamp, 0))) }) @@ -121,7 +131,7 @@ eventMgr.RegisterFunc(core.EventAll, func(data *core.EventData) { ### 2. Security Monitoring ```go -eventMgr.RegisterFunc(core.EventKickout, func(data *core.EventData) { +manager.RegisterFunc(core.EventKickout, func(data *core.EventData) { alertSystem.Send(fmt.Sprintf("User %s was kicked out", data.LoginID)) }) ``` @@ -129,7 +139,7 @@ eventMgr.RegisterFunc(core.EventKickout, func(data *core.EventData) { ### 3. Analytics ```go -eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { +manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { analytics.Track("user_login", map[string]interface{}{ "user_id": data.LoginID, "device": data.Device, @@ -140,7 +150,7 @@ eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { ### 4. Cache Invalidation ```go -eventMgr.RegisterFunc(core.EventLogout, func(data *core.EventData) { +manager.RegisterFunc(core.EventLogout, func(data *core.EventData) { cache.Delete("user:" + data.LoginID) }) ``` diff --git a/examples/listener-example/main.go b/examples/listener-example/main.go index a60d3f6..85c052a 100644 --- a/examples/listener-example/main.go +++ b/examples/listener-example/main.go @@ -12,26 +12,29 @@ import ( func main() { fmt.Println("=== Sa-Token-Go Event Listener Example ===\n") - // Create event manager - eventMgr := core.NewEventManager() - // 1. Simple function listener - eventMgr.RegisterFunc(core.EventLogin, func(data *core.EventData) { + manager := core.NewBuilder(). + Storage(memory.NewStorage()). + TokenName("Authorization"). + Timeout(7200). + Build() + + manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { fmt.Printf("[LOGIN] User %s logged in with token %s\n", data.LoginID, data.Token[:20]+"...") }) // 2. Logout listener - eventMgr.RegisterFunc(core.EventLogout, func(data *core.EventData) { + manager.RegisterFunc(core.EventLogout, func(data *core.EventData) { fmt.Printf("[LOGOUT] User %s logged out\n", data.LoginID) }) // 3. Kickout listener - eventMgr.RegisterFunc(core.EventKickout, func(data *core.EventData) { + manager.RegisterFunc(core.EventKickout, func(data *core.EventData) { fmt.Printf("[KICKOUT] User %s was forcibly logged out\n", data.LoginID) }) // 4. High-priority synchronous listener - eventMgr.RegisterWithConfig(core.EventLogin, + auditListenerID := manager.RegisterWithConfig(core.EventLogin, core.ListenerFunc(func(data *core.EventData) { fmt.Printf("[AUDIT] Login audit - User: %s, Time: %d\n", data.LoginID, data.Timestamp) @@ -44,23 +47,19 @@ func main() { ) // 5. Wildcard listener (all events) - eventMgr.RegisterFunc(core.EventAll, func(data *core.EventData) { + manager.RegisterFunc(core.EventAll, func(data *core.EventData) { fmt.Printf("[ALL EVENTS] %s\n", data.String()) }) + eventMgr := manager.GetEventManager() + // 6. Custom panic handler eventMgr.SetPanicHandler(func(event core.Event, data *core.EventData, recovered interface{}) { fmt.Printf("[PANIC RECOVERED] Event: %s, Error: %v\n", event, recovered) }) // Initialize Sa-Token - stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). - TokenName("Authorization"). - Timeout(7200). - Build(), - ) + stputil.SetManager(manager) fmt.Println("\n--- Triggering Events ---\n") @@ -80,7 +79,7 @@ func main() { time.Sleep(100 * time.Millisecond) // Wait for all async listeners to complete - eventMgr.Wait() + manager.WaitEvents() fmt.Println("\n--- Listener Statistics ---") fmt.Printf("Total listeners: %d\n", eventMgr.Count()) @@ -89,7 +88,7 @@ func main() { // Unregister a listener fmt.Println("\n--- Unregistering audit logger ---") - if eventMgr.Unregister("audit-logger") { + if manager.Unregister(auditListenerID) { fmt.Println("Audit logger unregistered successfully") } diff --git a/integrations/chi/annotation.go b/integrations/chi/annotation.go new file mode 100644 index 0000000..b8ec1d9 --- /dev/null +++ b/integrations/chi/annotation.go @@ -0,0 +1,131 @@ +package chi + +import ( + "net/http" + "strings" + + "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/stputil" +) + +// Annotation annotation structure | 注解结构体 +type Annotation struct { + CheckLogin bool `json:"checkLogin"` + CheckRole []string `json:"checkRole"` + CheckPermission []string `json:"checkPermission"` + CheckDisable bool `json:"checkDisable"` + Ignore bool `json:"ignore"` +} + +// GetHandler gets handler with annotations | 获取带注解的处理器 +func GetHandler(handler http.Handler, annotations ...*Annotation) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check if authentication should be ignored | 检查是否忽略认证 + if len(annotations) > 0 && annotations[0].Ignore { + if handler != nil { + handler.ServeHTTP(w, r) + } + return + } + + // Get token from context using configured TokenName | 从上下文获取Token(使用配置的TokenName) + ctx := NewChiContext(w, r) + saCtx := core.NewContext(ctx, stputil.GetManager()) + token := saCtx.GetTokenValue() + if token == "" { + writeErrorResponse(w, core.NewNotLoginError()) + return + } + + // Check login | 检查登录 + if !stputil.IsLogin(token) { + writeErrorResponse(w, core.NewNotLoginError()) + return + } + + // Get login ID | 获取登录ID + loginID, err := stputil.GetLoginID(token) + if err != nil { + writeErrorResponse(w, err) + return + } + + // Check if account is disabled | 检查是否被封禁 + if len(annotations) > 0 && annotations[0].CheckDisable { + if stputil.IsDisable(loginID) { + writeErrorResponse(w, core.NewAccountDisabledError(loginID)) + return + } + } + + // Check permission | 检查权限 + if len(annotations) > 0 && len(annotations[0].CheckPermission) > 0 { + hasPermission := false + for _, perm := range annotations[0].CheckPermission { + if stputil.HasPermission(loginID, strings.TrimSpace(perm)) { + hasPermission = true + break + } + } + if !hasPermission { + writeErrorResponse(w, core.NewPermissionDeniedError(strings.Join(annotations[0].CheckPermission, ","))) + return + } + } + + // Check role | 检查角色 + if len(annotations) > 0 && len(annotations[0].CheckRole) > 0 { + hasRole := false + for _, role := range annotations[0].CheckRole { + if stputil.HasRole(loginID, strings.TrimSpace(role)) { + hasRole = true + break + } + } + if !hasRole { + writeErrorResponse(w, core.NewRoleDeniedError(strings.Join(annotations[0].CheckRole, ","))) + return + } + } + + // All checks passed, execute original handler | 所有检查通过,执行原函数 + if handler != nil { + handler.ServeHTTP(w, r) + } + }) +} + +// CheckLoginMiddleware decorator for login checking | 检查登录装饰器 +func CheckLoginMiddleware() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return GetHandler(next, &Annotation{CheckLogin: true}) + } +} + +// CheckRoleMiddleware decorator for role checking | 检查角色装饰器 +func CheckRoleMiddleware(roles ...string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return GetHandler(next, &Annotation{CheckRole: roles}) + } +} + +// CheckPermissionMiddleware decorator for permission checking | 检查权限装饰器 +func CheckPermissionMiddleware(perms ...string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return GetHandler(next, &Annotation{CheckPermission: perms}) + } +} + +// CheckDisableMiddleware decorator for checking if account is disabled | 检查是否被封禁装饰器 +func CheckDisableMiddleware() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return GetHandler(next, &Annotation{CheckDisable: true}) + } +} + +// IgnoreMiddleware decorator to ignore authentication | 忽略认证装饰器 +func IgnoreMiddleware() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return GetHandler(next, &Annotation{Ignore: true}) + } +} diff --git a/integrations/echo/annotation.go b/integrations/echo/annotation.go new file mode 100644 index 0000000..a8c9804 --- /dev/null +++ b/integrations/echo/annotation.go @@ -0,0 +1,126 @@ +package echo + +import ( + "strings" + + "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/stputil" + "github.com/labstack/echo/v4" +) + +// Annotation annotation structure | 注解结构体 +type Annotation struct { + CheckLogin bool `json:"checkLogin"` + CheckRole []string `json:"checkRole"` + CheckPermission []string `json:"checkPermission"` + CheckDisable bool `json:"checkDisable"` + Ignore bool `json:"ignore"` +} + +// GetHandler gets handler with annotations | 获取带注解的处理器 +func GetHandler(handler echo.HandlerFunc, annotations ...*Annotation) echo.HandlerFunc { + return func(c echo.Context) error { + // Check if authentication should be ignored | 检查是否忽略认证 + if len(annotations) > 0 && annotations[0].Ignore { + if handler != nil { + return handler(c) + } + return nil + } + + // Get token from context using configured TokenName | 从上下文获取Token(使用配置的TokenName) + ctx := NewEchoContext(c) + saCtx := core.NewContext(ctx, stputil.GetManager()) + token := saCtx.GetTokenValue() + if token == "" { + return writeErrorResponse(c, core.NewNotLoginError()) + } + + // Check login | 检查登录 + if !stputil.IsLogin(token) { + return writeErrorResponse(c, core.NewNotLoginError()) + } + + // Get login ID | 获取登录ID + loginID, err := stputil.GetLoginID(token) + if err != nil { + return writeErrorResponse(c, err) + } + + // Check if account is disabled | 检查是否被封禁 + if len(annotations) > 0 && annotations[0].CheckDisable { + if stputil.IsDisable(loginID) { + return writeErrorResponse(c, core.NewAccountDisabledError(loginID)) + } + } + + // Check permission | 检查权限 + if len(annotations) > 0 && len(annotations[0].CheckPermission) > 0 { + hasPermission := false + for _, perm := range annotations[0].CheckPermission { + if stputil.HasPermission(loginID, strings.TrimSpace(perm)) { + hasPermission = true + break + } + } + if !hasPermission { + return writeErrorResponse(c, core.NewPermissionDeniedError(strings.Join(annotations[0].CheckPermission, ","))) + } + } + + // Check role | 检查角色 + if len(annotations) > 0 && len(annotations[0].CheckRole) > 0 { + hasRole := false + for _, role := range annotations[0].CheckRole { + if stputil.HasRole(loginID, strings.TrimSpace(role)) { + hasRole = true + break + } + } + if !hasRole { + return writeErrorResponse(c, core.NewRoleDeniedError(strings.Join(annotations[0].CheckRole, ","))) + } + } + + // All checks passed, execute original handler | 所有检查通过,执行原函数 + if handler != nil { + return handler(c) + } + return nil + } +} + +// CheckLoginMiddleware decorator for login checking | 检查登录装饰器 +func CheckLoginMiddleware() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return GetHandler(next, &Annotation{CheckLogin: true}) + } +} + +// CheckRoleMiddleware decorator for role checking | 检查角色装饰器 +func CheckRoleMiddleware(roles ...string) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return GetHandler(next, &Annotation{CheckRole: roles}) + } +} + +// CheckPermissionMiddleware decorator for permission checking | 检查权限装饰器 +func CheckPermissionMiddleware(perms ...string) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return GetHandler(next, &Annotation{CheckPermission: perms}) + } +} + +// CheckDisableMiddleware decorator for checking if account is disabled | 检查是否被封禁装饰器 +func CheckDisableMiddleware() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return GetHandler(next, &Annotation{CheckDisable: true}) + } +} + +// IgnoreMiddleware decorator to ignore authentication | 忽略认证装饰器 +func IgnoreMiddleware() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return GetHandler(next, &Annotation{Ignore: true}) + } +} diff --git a/integrations/fiber/annotation.go b/integrations/fiber/annotation.go new file mode 100644 index 0000000..0182710 --- /dev/null +++ b/integrations/fiber/annotation.go @@ -0,0 +1,116 @@ +package fiber + +import ( + "strings" + + "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/stputil" + "github.com/gofiber/fiber/v2" +) + +// Annotation annotation structure | 注解结构体 +type Annotation struct { + CheckLogin bool `json:"checkLogin"` + CheckRole []string `json:"checkRole"` + CheckPermission []string `json:"checkPermission"` + CheckDisable bool `json:"checkDisable"` + Ignore bool `json:"ignore"` +} + +// GetHandler gets handler with annotations | 获取带注解的处理器 +func GetHandler(handler fiber.Handler, annotations ...*Annotation) fiber.Handler { + return func(c *fiber.Ctx) error { + // Check if authentication should be ignored | 检查是否忽略认证 + if len(annotations) > 0 && annotations[0].Ignore { + if handler != nil { + return handler(c) + } + return c.Next() + } + + // Get token from context using configured TokenName | 从上下文获取Token(使用配置的TokenName) + ctx := NewFiberContext(c) + saCtx := core.NewContext(ctx, stputil.GetManager()) + token := saCtx.GetTokenValue() + if token == "" { + return writeErrorResponse(c, core.NewNotLoginError()) + } + + // Check login | 检查登录 + if !stputil.IsLogin(token) { + return writeErrorResponse(c, core.NewNotLoginError()) + } + + // Get login ID | 获取登录ID + loginID, err := stputil.GetLoginID(token) + if err != nil { + return writeErrorResponse(c, err) + } + + // Check if account is disabled | 检查是否被封禁 + if len(annotations) > 0 && annotations[0].CheckDisable { + if stputil.IsDisable(loginID) { + return writeErrorResponse(c, core.NewAccountDisabledError(loginID)) + } + } + + // Check permission | 检查权限 + if len(annotations) > 0 && len(annotations[0].CheckPermission) > 0 { + hasPermission := false + for _, perm := range annotations[0].CheckPermission { + if stputil.HasPermission(loginID, strings.TrimSpace(perm)) { + hasPermission = true + break + } + } + if !hasPermission { + return writeErrorResponse(c, core.NewPermissionDeniedError(strings.Join(annotations[0].CheckPermission, ","))) + } + } + + // Check role | 检查角色 + if len(annotations) > 0 && len(annotations[0].CheckRole) > 0 { + hasRole := false + for _, role := range annotations[0].CheckRole { + if stputil.HasRole(loginID, strings.TrimSpace(role)) { + hasRole = true + break + } + } + if !hasRole { + return writeErrorResponse(c, core.NewRoleDeniedError(strings.Join(annotations[0].CheckRole, ","))) + } + } + + // All checks passed, execute original handler | 所有检查通过,执行原函数 + if handler != nil { + return handler(c) + } + return c.Next() + } +} + +// CheckLoginMiddleware decorator for login checking | 检查登录装饰器 +func CheckLoginMiddleware() fiber.Handler { + return GetHandler(nil, &Annotation{CheckLogin: true}) +} + +// CheckRoleMiddleware decorator for role checking | 检查角色装饰器 +func CheckRoleMiddleware(roles ...string) fiber.Handler { + return GetHandler(nil, &Annotation{CheckRole: roles}) +} + +// CheckPermissionMiddleware decorator for permission checking | 检查权限装饰器 +func CheckPermissionMiddleware(perms ...string) fiber.Handler { + return GetHandler(nil, &Annotation{CheckPermission: perms}) +} + +// CheckDisableMiddleware decorator for checking if account is disabled | 检查是否被封禁装饰器 +func CheckDisableMiddleware() fiber.Handler { + return GetHandler(nil, &Annotation{CheckDisable: true}) +} + +// IgnoreMiddleware decorator to ignore authentication | 忽略认证装饰器 +func IgnoreMiddleware() fiber.Handler { + return GetHandler(nil, &Annotation{Ignore: true}) +} diff --git a/integrations/gf/annotation.go b/integrations/gf/annotation.go new file mode 100644 index 0000000..d1e0630 --- /dev/null +++ b/integrations/gf/annotation.go @@ -0,0 +1,125 @@ +package gf + +import ( + "strings" + + "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/stputil" + "github.com/gogf/gf/v2/net/ghttp" +) + +// Annotation annotation structure | 注解结构体 +type Annotation struct { + CheckLogin bool `json:"checkLogin"` + CheckRole []string `json:"checkRole"` + CheckPermission []string `json:"checkPermission"` + CheckDisable bool `json:"checkDisable"` + Ignore bool `json:"ignore"` +} + +// GetHandler gets handler with annotations | 获取带注解的处理器 +func GetHandler(handler ghttp.HandlerFunc, annotations ...*Annotation) ghttp.HandlerFunc { + return func(r *ghttp.Request) { + // Check if authentication should be ignored | 检查是否忽略认证 + if len(annotations) > 0 && annotations[0].Ignore { + if handler != nil { + handler(r) + } else { + r.Middleware.Next() + } + return + } + + // Get token from context using configured TokenName | 从上下文获取Token(使用配置的TokenName) + ctx := NewGFContext(r) + saCtx := core.NewContext(ctx, stputil.GetManager()) + token := saCtx.GetTokenValue() + if token == "" { + writeErrorResponse(r, core.NewNotLoginError()) + return + } + + // Check login | 检查登录 + if !stputil.IsLogin(token) { + writeErrorResponse(r, core.NewNotLoginError()) + return + } + + // Get login ID | 获取登录ID + loginID, err := stputil.GetLoginID(token) + if err != nil { + writeErrorResponse(r, err) + return + } + + // Check if account is disabled | 检查是否被封禁 + if len(annotations) > 0 && annotations[0].CheckDisable { + if stputil.IsDisable(loginID) { + writeErrorResponse(r, core.NewAccountDisabledError(loginID)) + return + } + } + + // Check permission | 检查权限 + if len(annotations) > 0 && len(annotations[0].CheckPermission) > 0 { + hasPermission := false + for _, perm := range annotations[0].CheckPermission { + if stputil.HasPermission(loginID, strings.TrimSpace(perm)) { + hasPermission = true + break + } + } + if !hasPermission { + writeErrorResponse(r, core.NewPermissionDeniedError(strings.Join(annotations[0].CheckPermission, ","))) + return + } + } + + // Check role | 检查角色 + if len(annotations) > 0 && len(annotations[0].CheckRole) > 0 { + hasRole := false + for _, role := range annotations[0].CheckRole { + if stputil.HasRole(loginID, strings.TrimSpace(role)) { + hasRole = true + break + } + } + if !hasRole { + writeErrorResponse(r, core.NewRoleDeniedError(strings.Join(annotations[0].CheckRole, ","))) + return + } + } + + // All checks passed, execute original handler | 所有检查通过,执行原函数 + if handler != nil { + handler(r) + } else { + r.Middleware.Next() + } + } +} + +// CheckLoginMiddleware decorator for login checking | 检查登录装饰器 +func CheckLoginMiddleware() ghttp.HandlerFunc { + return GetHandler(nil, &Annotation{CheckLogin: true}) +} + +// CheckRoleMiddleware decorator for role checking | 检查角色装饰器 +func CheckRoleMiddleware(roles ...string) ghttp.HandlerFunc { + return GetHandler(nil, &Annotation{CheckRole: roles}) +} + +// CheckPermissionMiddleware decorator for permission checking | 检查权限装饰器 +func CheckPermissionMiddleware(perms ...string) ghttp.HandlerFunc { + return GetHandler(nil, &Annotation{CheckPermission: perms}) +} + +// CheckDisableMiddleware decorator for checking if account is disabled | 检查是否被封禁装饰器 +func CheckDisableMiddleware() ghttp.HandlerFunc { + return GetHandler(nil, &Annotation{CheckDisable: true}) +} + +// IgnoreMiddleware decorator to ignore authentication | 忽略认证装饰器 +func IgnoreMiddleware() ghttp.HandlerFunc { + return GetHandler(nil, &Annotation{Ignore: true}) +} diff --git a/integrations/gin/annotation.go b/integrations/gin/annotation.go index bbfb336..cc7a1dc 100644 --- a/integrations/gin/annotation.go +++ b/integrations/gin/annotation.go @@ -1,10 +1,10 @@ package gin import ( - "net/http" "reflect" "strings" + "github.com/click33/sa-token-go/core" "github.com/click33/sa-token-go/stputil" ginfw "github.com/gin-gonic/gin" ) @@ -92,53 +92,43 @@ func GetHandler(handler interface{}, annotations ...*Annotation) ginfw.HandlerFu return func(c *ginfw.Context) { // Check if authentication should be ignored | 检查是否忽略认证 if len(annotations) > 0 && annotations[0].Ignore { - if handler != nil { - handler.(func(*ginfw.Context))(c) - } else { - c.Next() + if callHandler(handler, c) { + return } + c.Next() return } - // Get token | 获取Token - token := c.GetHeader("Authorization") + // Get token from context using configured TokenName | 从上下文获取Token(使用配置的TokenName) + ctx := NewGinContext(c) + saCtx := core.NewContext(ctx, stputil.GetManager()) + token := saCtx.GetTokenValue() if token == "" { - token = c.GetHeader("satoken") - } - if token == "" { - c.AbortWithStatusJSON(http.StatusUnauthorized, ginfw.H{ - "code": 401, - "message": "未登录", - }) + writeErrorResponse(c, core.NewNotLoginError()) + c.Abort() return } // Check login | 检查登录 if !stputil.IsLogin(token) { - c.AbortWithStatusJSON(http.StatusUnauthorized, ginfw.H{ - "code": 401, - "message": "未登录", - }) + writeErrorResponse(c, core.NewNotLoginError()) + c.Abort() return } // Get login ID | 获取登录ID loginID, err := stputil.GetLoginID(token) if err != nil { - c.AbortWithStatusJSON(http.StatusUnauthorized, ginfw.H{ - "code": 401, - "message": "登录状态无效", - }) + writeErrorResponse(c, err) + c.Abort() return } // Check if account is disabled | 检查是否被封禁 if len(annotations) > 0 && annotations[0].CheckDisable { if stputil.IsDisable(loginID) { - c.AbortWithStatusJSON(http.StatusForbidden, ginfw.H{ - "code": 403, - "message": "账号已被封禁", - }) + writeErrorResponse(c, core.NewAccountDisabledError(loginID)) + c.Abort() return } } @@ -153,10 +143,8 @@ func GetHandler(handler interface{}, annotations ...*Annotation) ginfw.HandlerFu } } if !hasPermission { - c.AbortWithStatusJSON(http.StatusForbidden, ginfw.H{ - "code": 403, - "message": "权限不足", - }) + writeErrorResponse(c, core.NewPermissionDeniedError(strings.Join(annotations[0].CheckPermission, ","))) + c.Abort() return } } @@ -171,23 +159,54 @@ func GetHandler(handler interface{}, annotations ...*Annotation) ginfw.HandlerFu } } if !hasRole { - c.AbortWithStatusJSON(http.StatusForbidden, ginfw.H{ - "code": 403, - "message": "角色不足", - }) + writeErrorResponse(c, core.NewRoleDeniedError(strings.Join(annotations[0].CheckRole, ","))) + c.Abort() return } } // All checks passed, execute original handler or continue | 所有检查通过,执行原函数或继续 - if handler != nil { - handler.(func(*ginfw.Context))(c) - } else { - c.Next() + if callHandler(handler, c) { + return } + c.Next() } } +func callHandler(handler interface{}, c *ginfw.Context) bool { + if handler == nil { + return false + } + + switch h := handler.(type) { + case func(*ginfw.Context): + if h == nil { + return false + } + h(c) + return true + case ginfw.HandlerFunc: + if h == nil { + return false + } + h(c) + return true + } + + hv := reflect.ValueOf(handler) + if hv.Kind() != reflect.Func || hv.IsNil() || hv.Type().NumIn() != 1 { + return false + } + + argType := hv.Type().In(0) + if !argType.AssignableTo(reflect.TypeOf(c)) { + return false + } + + hv.Call([]reflect.Value{reflect.ValueOf(c)}) + return true +} + // Decorator functions | 装饰器函数 // CheckLogin decorator for login checking | 检查登录装饰器 @@ -277,45 +296,36 @@ func Middleware(annotations ...*Annotation) ginfw.HandlerFunc { return } - // 获取Token - token := c.GetHeader("Authorization") + // 获取Token(使用配置的TokenName) + ctx := NewGinContext(c) + saCtx := core.NewContext(ctx, stputil.GetManager()) + token := saCtx.GetTokenValue() if token == "" { - token = c.GetHeader("satoken") - } - if token == "" { - c.AbortWithStatusJSON(http.StatusUnauthorized, ginfw.H{ - "code": 401, - "message": "未登录", - }) + writeErrorResponse(c, core.NewNotLoginError()) + c.Abort() return } // 检查登录 if !stputil.IsLogin(token) { - c.AbortWithStatusJSON(http.StatusUnauthorized, ginfw.H{ - "code": 401, - "message": "未登录", - }) + writeErrorResponse(c, core.NewNotLoginError()) + c.Abort() return } // 获取登录ID loginID, err := stputil.GetLoginID(token) if err != nil { - c.AbortWithStatusJSON(http.StatusUnauthorized, ginfw.H{ - "code": 401, - "message": "登录状态无效", - }) + writeErrorResponse(c, err) + c.Abort() return } // 检查是否被封禁 if len(annotations) > 0 && annotations[0].CheckDisable { if stputil.IsDisable(loginID) { - c.AbortWithStatusJSON(http.StatusForbidden, ginfw.H{ - "code": 403, - "message": "账号已被封禁", - }) + writeErrorResponse(c, core.NewAccountDisabledError(loginID)) + c.Abort() return } } @@ -330,10 +340,8 @@ func Middleware(annotations ...*Annotation) ginfw.HandlerFunc { } } if !hasPermission { - c.AbortWithStatusJSON(http.StatusForbidden, ginfw.H{ - "code": 403, - "message": "权限不足", - }) + writeErrorResponse(c, core.NewPermissionDeniedError(strings.Join(annotations[0].CheckPermission, ","))) + c.Abort() return } } @@ -348,10 +356,8 @@ func Middleware(annotations ...*Annotation) ginfw.HandlerFunc { } } if !hasRole { - c.AbortWithStatusJSON(http.StatusForbidden, ginfw.H{ - "code": 403, - "message": "角色不足", - }) + writeErrorResponse(c, core.NewRoleDeniedError(strings.Join(annotations[0].CheckRole, ","))) + c.Abort() return } }