Files
sa-token-go/integrations/kratos/plugin.go
2025-11-21 23:51:34 +08:00

366 lines
8.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package kratos
import (
"context"
"sort"
"strings"
"github.com/click33/sa-token-go/core"
"github.com/go-kratos/kratos/v2/middleware"
"github.com/go-kratos/kratos/v2/transport"
)
// Plugin is the sa-token authentication engine for Kratos. It holds the
// registered authorization rules and the options that control request
// handling; Server turns it into a Kratos middleware.
type Plugin struct {
	// manager is passed to core.NewContext and to every rule checker.
	manager *core.Manager
	// rules are the registered authorization rules, in registration order.
	rules []Rule
	// options holds skip patterns, the default-login switch and the error handler.
	options *PluginOptions
}
// Rule is a single authorization rule: when Matcher accepts an operation,
// every checker in Checkers must pass for the request to proceed.
type Rule struct {
	// Matcher decides whether this rule applies to a given operation.
	Matcher OperationMatcher
	// Checkers are run in order; the first failing checker rejects the request.
	Checkers []Checker
	// Priority orders overlapping rules: a larger number means a higher
	// priority, and only the highest-priority matching rule is applied.
	// Defaults to 0.
	Priority int
}
// NewPlugin creates an authentication plugin bound to manager.
//
// Only the first element of opts is consulted. When it is absent or nil, the
// plugin uses defaultPluginOptions(). A supplied options pointer is stored
// as-is, so later calls such as Skip or DefaultRequireLogin modify it in place.
func NewPlugin(manager *core.Manager, opts ...*PluginOptions) *Plugin {
	var options *PluginOptions
	if len(opts) > 0 {
		options = opts[0]
	}
	if options == nil {
		options = defaultPluginOptions()
	}
	return &Plugin{
		manager: manager,
		rules:   make([]Rule, 0),
		options: options,
	}
}
// Server returns a Kratos server middleware that enforces the plugin's rules.
//
// For each request:
//   - requests without server transport info in ctx pass through unchecked;
//   - operations matching a skip pattern (see Skip) pass through unchecked;
//   - when no rule matches, login is enforced only if DefaultRequireLogin is set;
//   - when a rule matches, the caller must be logged in and every checker of
//     the rule must succeed, in order.
//
// The sa-token context is stored in ctx under the "satoken" key before the
// rule's checkers and the downstream handler run, so custom checkers can read
// it as well. Failures are passed through options.ErrorHandler together with
// the incoming ctx. If no handler is configured, the error is returned
// unchanged instead of panicking on a nil function call.
func (e *Plugin) Server() middleware.Middleware {
	// reject maps an authentication failure to the error returned to the
	// client. ErrorHandler is read at call time so SetErrorHandler still takes
	// effect after Server has been called.
	reject := func(ctx context.Context, err error) error {
		if h := e.options.ErrorHandler; h != nil {
			return h(ctx, err)
		}
		return err
	}

	return func(handler middleware.Handler) middleware.Handler {
		return func(ctx context.Context, req interface{}) (interface{}, error) {
			info, ok := transport.FromServerContext(ctx)
			if !ok {
				// No transport info: nothing to authenticate against, pass through.
				return handler(ctx, req)
			}

			operation := info.Operation()
			// Skipped operations never use the sa-token context, so check
			// this before building it.
			if e.shouldSkip(operation) {
				return handler(ctx, req)
			}

			saCtx := core.NewContext(NewKratosContext(ctx), e.manager)
			// NOTE(review): a built-in string key is flagged by staticcheck
			// (SA1029). It is kept because code elsewhere presumably reads
			// ctx.Value("satoken"); changing the key type would break it.
			authCtx := context.WithValue(ctx, "satoken", saCtx)

			rule, found := e.findRule(operation)
			if !found {
				if e.options.DefaultRequireLogin && !saCtx.IsLogin() {
					return nil, reject(ctx, core.ErrNotLogin)
				}
				return handler(authCtx, req)
			}

			// A matched rule always requires a logged-in caller.
			loginID, err := saCtx.GetLoginID()
			if err != nil {
				return nil, reject(ctx, core.ErrNotLogin)
			}
			for _, checker := range rule.Checkers {
				if checkErr := checker.Check(authCtx, e.manager, loginID); checkErr != nil {
					return nil, reject(ctx, checkErr)
				}
			}
			return handler(authCtx, req)
		}
	}
}
// ========== 规则构建器 ==========
// RuleBuilder assembles a Rule through a chainable API. One of the Plugin
// matcher methods (AutoMatcher, ExactMatcher, ...) creates it, and Build
// registers the finished rule on the plugin.
type RuleBuilder struct {
	// plugin is the engine the built rule is added to.
	plugin *Plugin
	// matcher decides which operations the rule applies to.
	matcher OperationMatcher
	// checkers accumulate from the Require*/Check*/AddChecker calls, in call order.
	checkers []Checker
	// priority is copied into Rule.Priority; set via WithPriority.
	priority int
}
// AutoMatcher starts a rule for pattern and lets newPatternMatcher choose the
// concrete matcher implementation. The rule starts with no checks and
// priority 0.
func (e *Plugin) AutoMatcher(pattern string) *RuleBuilder {
	matcher := newPatternMatcher(pattern)
	return &RuleBuilder{
		plugin:   e,
		matcher:  matcher,
		checkers: []Checker{},
	}
}
// ExactMatcher starts a rule that uses an ExactMatcher for operation. The rule
// starts with no checks and priority 0.
func (e *Plugin) ExactMatcher(operation string) *RuleBuilder {
	rb := &RuleBuilder{plugin: e}
	rb.matcher = &ExactMatcher{operation: operation}
	rb.checkers = make([]Checker, 0)
	return rb
}
// PrefixMatcher starts a rule that uses a PrefixMatcher for prefix. The rule
// starts with no checks and priority 0.
func (e *Plugin) PrefixMatcher(prefix string) *RuleBuilder {
	m := &PrefixMatcher{prefix: prefix}
	return &RuleBuilder{plugin: e, matcher: m, checkers: []Checker{}}
}
// SuffixMatcher starts a rule that uses a SuffixMatcher for suffix. The rule
// starts with no checks and priority 0.
func (e *Plugin) SuffixMatcher(suffix string) *RuleBuilder {
	rb := new(RuleBuilder)
	rb.plugin = e
	rb.matcher = &SuffixMatcher{suffix: suffix}
	rb.checkers = []Checker{}
	return rb
}
// PatternMatcher starts a rule that uses a WildcardMatcher for pattern. The
// rule starts with no checks and priority 0.
func (e *Plugin) PatternMatcher(pattern string) *RuleBuilder {
	wildcard := &WildcardMatcher{pattern: pattern}
	return &RuleBuilder{
		plugin:   e,
		matcher:  wildcard,
		checkers: make([]Checker, 0),
	}
}
// RegexMatcher starts a rule whose matcher is built by newRegexMatcher from
// regex. The rule starts with no checks and priority 0.
// NOTE(review): handling of an invalid expression is up to newRegexMatcher —
// confirm whether it panics or yields a never-matching matcher.
func (e *Plugin) RegexMatcher(regex string) *RuleBuilder {
	rb := &RuleBuilder{plugin: e}
	rb.matcher = newRegexMatcher(regex)
	rb.checkers = []Checker{}
	return rb
}
// ContainsMatcher starts a rule that uses a ContainsMatcher for substring.
// The rule starts with no checks and priority 0.
func (e *Plugin) ContainsMatcher(substring string) *RuleBuilder {
	m := &ContainsMatcher{substring: substring}
	return &RuleBuilder{plugin: e, matcher: m, checkers: make([]Checker, 0)}
}
// FuncMatcher starts a rule whose applicability is decided by fn, which
// receives the operation name. The rule starts with no checks and priority 0.
//
// NOTE(review): the optional name argument is accepted but currently ignored.
func (e *Plugin) FuncMatcher(fn func(operation string) bool, name ...string) *RuleBuilder {
	rb := new(RuleBuilder)
	rb.plugin = e
	rb.matcher = &FuncMatcher{fn: fn}
	rb.checkers = make([]Checker, 0)
	return rb
}
// CustomMatcher starts a rule that uses a caller-supplied OperationMatcher.
// The rule starts with no checks and priority 0.
func (e *Plugin) CustomMatcher(matcher OperationMatcher) *RuleBuilder {
	rb := &RuleBuilder{plugin: e, matcher: matcher}
	rb.checkers = []Checker{}
	return rb
}
// ========== RuleBuilder 方法 ==========
// RequireLogin appends a LoginChecker to the rule (the caller must be logged in).
func (rb *RuleBuilder) RequireLogin() *RuleBuilder {
	return rb.AddChecker(&LoginChecker{})
}
// RequirePermission appends a PermissionChecker requiring the given permission.
func (rb *RuleBuilder) RequirePermission(permission string) *RuleBuilder {
	return rb.AddChecker(&PermissionChecker{permission: permission})
}
// RequirePermissions appends a PermissionsAndChecker: every listed permission
// is required (AND logic).
func (rb *RuleBuilder) RequirePermissions(permissions ...string) *RuleBuilder {
	return rb.AddChecker(&PermissionsAndChecker{permissions: permissions})
}
// RequireAnyPermission appends a PermissionsOrChecker: at least one listed
// permission is required (OR logic).
func (rb *RuleBuilder) RequireAnyPermission(permissions ...string) *RuleBuilder {
	return rb.AddChecker(&PermissionsOrChecker{permissions: permissions})
}
// RequireRole appends a RoleChecker requiring the given role.
func (rb *RuleBuilder) RequireRole(role string) *RuleBuilder {
	return rb.AddChecker(&RoleChecker{role: role})
}
// RequireRoles appends a RolesAndChecker: every listed role is required
// (AND logic).
func (rb *RuleBuilder) RequireRoles(roles ...string) *RuleBuilder {
	return rb.AddChecker(&RolesAndChecker{roles: roles})
}
// RequireAnyRole appends a RolesOrChecker: at least one listed role is
// required (OR logic).
func (rb *RuleBuilder) RequireAnyRole(roles ...string) *RuleBuilder {
	return rb.AddChecker(&RolesOrChecker{roles: roles})
}
// CheckNotDisabled appends a DisableChecker, which verifies that the account
// has not been disabled (banned).
func (rb *RuleBuilder) CheckNotDisabled() *RuleBuilder {
	return rb.AddChecker(&DisableChecker{})
}
// CustomCheck appends a CustomChecker wrapping fn. fn receives the request
// context, the manager and the caller's login ID; a non-nil error rejects
// the request.
func (rb *RuleBuilder) CustomCheck(fn func(ctx context.Context, manager *core.Manager, loginID string) error) *RuleBuilder {
	return rb.AddChecker(&CustomChecker{fn: fn})
}
// AddChecker appends an arbitrary Checker to the rule. Checkers run in the
// order they were added.
func (rb *RuleBuilder) AddChecker(checker Checker) *RuleBuilder {
	updated := append(rb.checkers, checker)
	rb.checkers = updated
	return rb
}
// WithPriority sets the rule priority; when several rules match an operation,
// the one with the larger priority is applied.
func (rb *RuleBuilder) WithPriority(priority int) *RuleBuilder {
	b := rb
	b.priority = priority
	return b
}
// Build turns the accumulated matcher, checkers and priority into a Rule,
// registers it on the owning plugin and returns that plugin for chaining.
func (rb *RuleBuilder) Build() *Plugin {
	return rb.plugin.AddRule(Rule{
		Matcher:  rb.matcher,
		Checkers: rb.checkers,
		Priority: rb.priority,
	})
}
// ========== Plugin便捷方法 ==========
// Skip adds operation patterns that bypass authentication entirely. Patterns
// are appended to options.SkipOperations in place.
func (e *Plugin) Skip(operations ...string) *Plugin {
	opts := e.options
	opts.SkipOperations = append(opts.SkipOperations, operations...)
	return e
}
// DefaultRequireLogin controls whether operations that match no rule still
// require the caller to be logged in.
func (e *Plugin) DefaultRequireLogin(require bool) *Plugin {
	opts := e.options
	opts.DefaultRequireLogin = require
	return e
}
// SetErrorHandler replaces the function that converts authentication
// failures into the error returned by the middleware.
func (e *Plugin) SetErrorHandler(handler func(ctx context.Context, err error) error) *Plugin {
	opts := e.options
	opts.ErrorHandler = handler
	return e
}
// AddRule registers a pre-built Rule directly, bypassing RuleBuilder.
func (e *Plugin) AddRule(rule Rule) *Plugin {
	return e.AddRules(rule)
}
// AddRules registers several pre-built rules, preserving their order.
func (e *Plugin) AddRules(rules ...Rule) *Plugin {
	for i := range rules {
		e.addRule(rules[i])
	}
	return e
}
// ========== 内部方法 ==========
// shouldSkip reports whether operation matches any configured skip pattern,
// stopping at the first match.
func (e *Plugin) shouldSkip(operation string) bool {
	skip := false
	patterns := e.options.SkipOperations
	for i := 0; i < len(patterns) && !skip; i++ {
		skip = matchPattern(patterns[i], operation)
	}
	return skip
}
// findRule returns the highest-priority rule whose Matcher accepts operation,
// or (Rule{}, false) when no rule matches.
//
// Ties are broken by registration order: among matching rules that share the
// highest priority, the one added first wins. sort.SliceStable is required
// for that guarantee. sort.Slice is not stable, so with it the winner among
// equal-priority rules (every rule defaults to priority 0) was unspecified.
func (e *Plugin) findRule(operation string) (Rule, bool) {
	var matched []Rule
	for _, rule := range e.rules {
		if rule.Matcher.Match(operation) {
			matched = append(matched, rule)
		}
	}
	if len(matched) == 0 {
		return Rule{}, false
	}

	// Descending priority; stable so equal priorities keep registration order.
	sort.SliceStable(matched, func(i, j int) bool {
		return matched[i].Priority > matched[j].Priority
	})
	return matched[0], true
}
// addRule appends rule to the plugin's rule list, keeping registration order.
func (e *Plugin) addRule(rule Rule) {
	rules := append(e.rules, rule)
	e.rules = rules
}
// matchPattern reports whether str matches a skip pattern.
//
// Supported forms:
//   - an exact string: matches only itself;
//   - "*": matches everything;
//   - "prefix*": str starts with prefix;
//   - "*suffix": str ends with suffix;
//   - any other pattern containing '*' or '?': delegated to wildcardMatch.
//
// The prefix/suffix fast paths are taken only when the rest of the pattern
// contains no further wildcard characters. Otherwise a pattern such as
// "*User*" or "/api.*.Get*" would be compared literally, inner '*' included,
// and would never match.
func matchPattern(pattern, str string) bool {
	if pattern == str || pattern == "*" {
		return true
	}
	if !strings.ContainsAny(pattern, "*?") {
		// No wildcards and not an exact match.
		return false
	}

	// Fast path: "prefix*" with a wildcard-free prefix.
	if strings.HasSuffix(pattern, "*") {
		if prefix := strings.TrimSuffix(pattern, "*"); !strings.ContainsAny(prefix, "*?") {
			return strings.HasPrefix(str, prefix)
		}
	}
	// Fast path: "*suffix" with a wildcard-free suffix.
	if strings.HasPrefix(pattern, "*") {
		if suffix := strings.TrimPrefix(pattern, "*"); !strings.ContainsAny(suffix, "*?") {
			return strings.HasSuffix(str, suffix)
		}
	}

	// General wildcard pattern; wildcardMatch is defined elsewhere in the
	// package (presumably glob-style '*' / '?' matching — confirm).
	return wildcardMatch(pattern, str)
}