refactor(security): 优化刷新令牌反序列化逻辑

- 引入 utils.ToBytes 方法统一数据转换
- 使用 UnmarshalBinary 替代类型断言解析 RefreshTokenInfo
- 增强错误处理和类型兼容性
- 更新访问令牌映射存储逻辑
-修复刷新令牌管理器初始化参数传递问题- 新增 ToBytes 工具函数支持节多种类型转字切片
This commit is contained in:
zhuyanwei
2025-11-11 10:30:20 +08:00
parent ccbb999eb9
commit 57298a107f
3 changed files with 45 additions and 5 deletions

View File

@@ -182,7 +182,7 @@ func NewNonceManager(storage Storage, prefix string, ttl ...int64) *NonceManager
// NewRefreshTokenManager Creates a new refresh token manager | 创建新的刷新令牌管理器
func NewRefreshTokenManager(storage Storage, prefix string, cfg *Config) *RefreshTokenManager {
return security.NewRefreshTokenManager(storage, prefix, cfg)
return security.NewRefreshTokenManager(storage, prefix, manager.TokenKeyPrefix, cfg)
}
// NewOAuth2Server Creates a new OAuth2 server | 创建新的OAuth2服务器

View File

@@ -10,6 +10,7 @@ import (
"github.com/click33/sa-token-go/core/adapter"
"github.com/click33/sa-token-go/core/config"
"github.com/click33/sa-token-go/core/token"
"github.com/click33/sa-token-go/core/utils"
)
// Refresh Token Implementation
@@ -140,15 +141,25 @@ func (rtm *RefreshTokenManager) RefreshAccessToken(refreshToken string) (*Refres
return nil, ErrInvalidRefreshToken
}
// Get refresh token info | 获取刷新令牌信息
key := rtm.getRefreshKey(refreshToken)
// Get refresh token info | 获取刷新令牌信息
data, err := rtm.storage.Get(key)
if err != nil || data == nil {
return nil, ErrInvalidRefreshToken
}
oldInfo, ok := data.(*RefreshTokenInfo)
if !ok {
// Convert to RefreshTokenInfo | 转换为 RefreshTokenInfo
dataBytes, err := utils.ToBytes(data)
if err != nil {
return nil, ErrInvalidRefreshData
}
// Unmarshal data | 反序列化数据
oldInfo := &RefreshTokenInfo{}
err = oldInfo.UnmarshalBinary(dataBytes)
if err != nil {
return nil, ErrInvalidRefreshData
}
@@ -164,8 +175,15 @@ func (rtm *RefreshTokenManager) RefreshAccessToken(refreshToken string) (*Refres
return nil, fmt.Errorf("failed to generate new access token: %w", err)
}
// Update access token info | 更新访问令牌信息
oldInfo.AccessToken = newAccessToken
// Save token-loginID mapping (符合 Java sa-token 设计) | 保存 Token-LoginID 映射
tokenKey := rtm.getTokenKey(newAccessToken)
if err := rtm.storage.Set(tokenKey, oldInfo.LoginID, rtm.accessTTL); err != nil {
return nil, fmt.Errorf("failed to save token: %w", err)
}
// Update storage | 更新存储
if err := rtm.storage.Set(key, oldInfo, rtm.refreshTTL); err != nil {
return nil, fmt.Errorf("failed to update refresh token: %w", err)
@@ -196,8 +214,14 @@ func (rtm *RefreshTokenManager) GetRefreshTokenInfo(refreshToken string) (*Refre
return nil, ErrInvalidRefreshToken
}
info, ok := data.(*RefreshTokenInfo)
if !ok {
dataBytes, err := utils.ToBytes(data)
if err != nil {
return nil, ErrInvalidRefreshData
}
info := &RefreshTokenInfo{}
err = info.UnmarshalBinary(dataBytes)
if err != nil {
return nil, ErrInvalidRefreshData
}

View File

@@ -478,6 +478,22 @@ func ToBool(v any) (bool, error) {
}
}
// ToBytes Converts any to bytes | 将any转换为字节
func ToBytes(value any) ([]byte, error) {
switch v := value.(type) {
case string:
return []byte(v), nil
case []byte:
return v, nil
case byte:
return []byte{v}, nil
case rune:
return []byte(string(v)), nil
default:
return nil, fmt.Errorf("unsupported type: %T", value)
}
}
// ============ Hash & Encoding | 哈希和编码 ============
// SHA256Hash Generates SHA256 hash of string | 生成字符串的SHA256哈希