新增oauth2-introspection插件

This commit is contained in:
Liujian
2025-05-14 17:31:09 +08:00
parent 2a0e090f44
commit 417bdff6d9
10 changed files with 492 additions and 14 deletions

View File

@@ -16,6 +16,7 @@ import (
"github.com/eolinker/apinto/drivers/plugins/gzip" "github.com/eolinker/apinto/drivers/plugins/gzip"
js_inject "github.com/eolinker/apinto/drivers/plugins/js-inject" js_inject "github.com/eolinker/apinto/drivers/plugins/js-inject"
"github.com/eolinker/apinto/drivers/plugins/oauth2" "github.com/eolinker/apinto/drivers/plugins/oauth2"
oauth2_introspection "github.com/eolinker/apinto/drivers/plugins/oauth2-introspection"
params_check "github.com/eolinker/apinto/drivers/plugins/params-check" params_check "github.com/eolinker/apinto/drivers/plugins/params-check"
params_check_v2 "github.com/eolinker/apinto/drivers/plugins/params-check-v2" params_check_v2 "github.com/eolinker/apinto/drivers/plugins/params-check-v2"
"github.com/eolinker/apinto/drivers/plugins/prometheus" "github.com/eolinker/apinto/drivers/plugins/prometheus"
@@ -119,6 +120,7 @@ func pluginRegister(extenderRegister eosc.IExtenderDriverRegister) {
// 鉴权插件 // 鉴权插件
oauth2.Register(extenderRegister) oauth2.Register(extenderRegister)
oauth2_introspection.Register(extenderRegister)
// ai相关插件 // ai相关插件
ai_prompt.Register(extenderRegister) ai_prompt.Register(extenderRegister)

View File

@@ -15,14 +15,6 @@ var (
var validPosition = []string{PositionHeader, PositionQuery, PositionBody} var validPosition = []string{PositionHeader, PositionQuery, PositionBody}
//func GetToken(ctx http_service.IHttpContext, tokenName string, position string) (string, bool) {
// token, has := getToken(ctx, tokenName, position)
// if has {
// ctx.SetLabel("token", token)
// }
// return token, has
//}
func GetToken(ctx http_service.IHttpContext, tokenName string, position string) (string, bool) { func GetToken(ctx http_service.IHttpContext, tokenName string, position string) (string, bool) {
switch position { switch position {
case PositionHeader: case PositionHeader:

View File

@@ -7,7 +7,7 @@ import (
"github.com/eolinker/eosc" "github.com/eolinker/eosc"
) )
//Create 创建驱动实例 // Create 创建驱动实例
func Create(id, name string, v *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) { func Create(id, name string, v *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) {
cfg, err := checkConfig(v) cfg, err := checkConfig(v)
if err != nil { if err != nil {
@@ -21,8 +21,11 @@ func Create(id, name string, v *Config, workers map[eosc.RequireId]eosc.IWorker)
WorkerBase: drivers.Worker(id, name), WorkerBase: drivers.Worker(id, name),
} }
err = a.set(cfg) err = a.set(cfg)
if err != nil {
return nil, err
}
return a, err return a, nil
} }
func checkConfig(v interface{}) (*Config, error) { func checkConfig(v interface{}) (*Config, error) {
@@ -33,9 +36,7 @@ func checkConfig(v interface{}) (*Config, error) {
if conf.Anonymous && len(conf.Auth) > 0 { if conf.Anonymous && len(conf.Auth) > 0 {
return nil, errors.New("it is anonymous app,auths should be empty") return nil, errors.New("it is anonymous app,auths should be empty")
} }
if conf.Anonymous && len(conf.Auth) > 0 {
return nil, errors.New("it is anonymous app,auths should be empty")
}
for _, a := range conf.Auth { for _, a := range conf.Auth {
err := application.CheckPosition(a.Position) err := application.CheckPosition(a.Position)
if err != nil { if err != nil {

View File

@@ -14,6 +14,7 @@ var _ IManager = (*Manager)(nil)
type IManager interface { type IManager interface {
Get(id string) (application.IAuth, bool) Get(id string) (application.IAuth, bool)
GetApp(appId string) (application.IApp, bool)
List() []application.IAuthUser List() []application.IAuthUser
ListByDriver(driver string) []application.IAuthUser ListByDriver(driver string) []application.IAuthUser
Set(app application.IApp, filters []application.IAuth, users map[string][]application.ITransformConfig) Set(app application.IApp, filters []application.IAuth, users map[string][]application.ITransformConfig)
@@ -26,6 +27,7 @@ type IManager interface {
type Manager struct { type Manager struct {
// filters map[string]application.IAuthUser // filters map[string]application.IAuthUser
eosc.Untyped[string, application.IAuth] eosc.Untyped[string, application.IAuth]
apps eosc.Untyped[string, application.IApp]
appManager *AppManager appManager *AppManager
driverAlias map[string]string driverAlias map[string]string
drivers []string drivers []string
@@ -33,6 +35,13 @@ type Manager struct {
app application.IApp app application.IApp
} }
func (m *Manager) GetApp(appId string) (application.IApp, bool) {
if !strings.HasSuffix(appId, "@app") {
appId = appId + "@app"
}
return m.apps.Get(appId)
}
func (m *Manager) AnonymousApp() application.IApp { func (m *Manager) AnonymousApp() application.IApp {
m.locker.RLock() m.locker.RLock()
app := m.app app := m.app
@@ -47,7 +56,7 @@ func (m *Manager) SetAnonymousApp(app application.IApp) {
} }
func NewManager(driverAlias map[string]string, drivers []string) IManager { func NewManager(driverAlias map[string]string, drivers []string) IManager {
return &Manager{Untyped: eosc.BuildUntyped[string, application.IAuth](), appManager: NewAppManager(), driverAlias: driverAlias, drivers: drivers} return &Manager{Untyped: eosc.BuildUntyped[string, application.IAuth](), appManager: NewAppManager(), driverAlias: driverAlias, drivers: drivers, apps: eosc.BuildUntyped[string, application.IApp]()}
} }
func (m *Manager) List() []application.IAuthUser { func (m *Manager) List() []application.IAuthUser {
@@ -121,6 +130,7 @@ func (m *Manager) Set(app application.IApp, filters []application.IAuth, users m
} }
} }
} }
m.apps.Set(app.Id(), app)
return return
} }
@@ -150,4 +160,5 @@ func (m *Manager) Del(appID string) {
} }
} }
m.appManager.DelByAppID(appID) m.appManager.DelByAppID(appID)
m.apps.Del(appID)
} }

View File

@@ -74,6 +74,15 @@ func anonymousAppHandler(ctx http_service.IHttpContext) (bool, error) {
func (a *App) auth(ctx http_service.IHttpContext) error { func (a *App) auth(ctx http_service.IHttpContext) error {
log.Debug("start auth...") log.Debug("start auth...")
appId := ctx.GetLabel("application_id")
if appId != "" {
app, has := appManager.GetApp(appId)
if has {
setLabels(ctx, app.Labels())
return nil
}
}
if appManager.Count() < 1 { if appManager.Count() < 1 {
if a.forceAuth { if a.forceAuth {
return fmt.Errorf("no app to auth") return fmt.Errorf("no app to auth")

View File

@@ -43,6 +43,16 @@ func TestExecutor(t *testing.T) {
type Context struct { type Context struct {
} }
func (c *Context) ProxyClone() http_service.IRequest {
//TODO implement me
panic("implement me")
}
func (c *Context) SetProxy(proxy http_service.IRequest) {
//TODO implement me
panic("implement me")
}
func (c *Context) RequestId() string { func (c *Context) RequestId() string {
//TODO implement me //TODO implement me
panic("implement me") panic("implement me")

View File

@@ -0,0 +1,64 @@
package oauth2_introspection
import (
"fmt"
"net/url"
)
const (
positionHeader = "header"
positionQuery = "query"
positionBody = "body"
)
const (
redisKeyPrefix = "apinto:oauth2-introspection"
)
type Config struct {
IntrospectionEndpoint string `json:"introspection_endpoint"`
IntrospectionSSLVerify bool `json:"introspection_ssl_verify" default:"true"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
TokenHeader string `json:"token_header"`
Scopes []string `json:"scopes"`
TTL int `json:"ttl" default:"600"`
CustomClaimsForward []string `json:"custom_claims_forward"`
ConsumerBy string `json:"consumer_by"`
AllowAnonymous bool `json:"allow_anonymous" default:"false"`
HideCredential bool `json:"hide_credential" default:"false"`
}
func Check(conf *Config) error {
if conf.IntrospectionEndpoint == "" {
return fmt.Errorf("introspection_endpoint is required")
}
u, err := url.Parse(conf.IntrospectionEndpoint)
if err != nil {
return fmt.Errorf("introspection_endpoint is invalid: %w", err)
}
if u.Scheme == "" || u.Host == "" {
return fmt.Errorf("introspection_endpoint is invalid: %s", conf.IntrospectionEndpoint)
}
if conf.ClientID == "" {
return fmt.Errorf("client_id is required")
}
if conf.ClientSecret == "" {
return fmt.Errorf("client_secret is required")
}
if conf.TokenHeader == "" {
conf.TokenHeader = "Authorization"
}
if conf.ConsumerBy == "" {
conf.ConsumerBy = "client_id"
}
if conf.TTL <= 0 {
conf.TTL = 600
}
return nil
}

View File

@@ -0,0 +1,190 @@
package oauth2_introspection
import (
"crypto/tls"
"encoding/json"
"fmt"
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/apinto/resources"
scope_manager "github.com/eolinker/apinto/scope-manager"
"github.com/eolinker/eosc"
"github.com/eolinker/eosc/eocontext"
http_service "github.com/eolinker/eosc/eocontext/http-context"
"net/http"
"sync"
"time"
)
var _ http_service.HttpFilter = (*executor)(nil)
var _ eocontext.IFilter = (*executor)(nil)
var _ eosc.IWorker = (*executor)(nil)
type executor struct {
drivers.WorkerBase
client http.Client
endpoint string
clientId string
clientSecret string
tokenName string
scopes map[string]struct{}
ttl time.Duration
claims []string
consumeBy string
hideCredential bool
allowAnonymous bool
once sync.Once
cache scope_manager.IProxyOutput[resources.ICache]
}
func (e *executor) Start() error {
return nil
}
func (e *executor) Reset(conf interface{}, workers map[eosc.RequireId]eosc.IWorker) error {
return nil
}
func (e *executor) reset(conf *Config) error {
client := http.Client{
Timeout: 5 * time.Second,
}
if !conf.IntrospectionSSLVerify {
client.Transport = &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
}
e.hideCredential = conf.HideCredential
e.client = client
e.endpoint = conf.IntrospectionEndpoint
e.clientId = conf.ClientID
e.clientSecret = conf.ClientSecret
e.tokenName = conf.TokenHeader
e.scopes = make(map[string]struct{})
for _, scope := range conf.Scopes {
e.scopes[scope] = struct{}{}
}
e.ttl = time.Duration(conf.TTL) * time.Second
e.claims = conf.CustomClaimsForward
e.consumeBy = conf.ConsumerBy
e.allowAnonymous = conf.AllowAnonymous
return nil
}
func (e *executor) Stop() error {
e.Destroy()
return nil
}
func (e *executor) CheckSkill(skill string) bool {
return http_service.FilterSkillName == skill
}
func (e *executor) DoFilter(ctx eocontext.EoContext, next eocontext.IChain) (err error) {
return http_service.DoHttpFilter(e, ctx, next)
}
func (e *executor) Destroy() {
e.client.CloseIdleConnections()
return
}
func (e *executor) DoHttpFilter(ctx http_service.IHttpContext, next eocontext.IChain) (err error) {
token := retrieveAccessToken(ctx, positionHeader, e.tokenName)
if token == "" {
ctx.Response().SetBody([]byte("empty token"))
ctx.Response().SetStatus(http.StatusUnauthorized, "empty token")
return fmt.Errorf("empty token")
}
e.once.Do(func() {
e.cache = scope_manager.Auto[resources.ICache]("", "redis")
})
ctx.SetLabel("token", token)
var introspectionInfo *eosc.Base[IntrospectionResponseBody]
var cache resources.ICache
if len(e.cache.List()) > 0 {
cache = e.cache.List()[0]
}
if cache != nil {
d, err := cache.Get(ctx.Context(), fmt.Sprintf("%s:%s", redisKeyPrefix, token)).Result()
if err == nil {
var t eosc.Base[IntrospectionResponseBody]
err = json.Unmarshal([]byte(d), &t)
if err == nil {
introspectionInfo = &t
}
}
}
if (introspectionInfo != nil && !checkActive(introspectionInfo.Config)) || introspectionInfo == nil {
// 当缓存信息不存在或者缓存信息过期时,重新发起请求
introspectionInfo, err = doIntrospectAccessToken(&e.client, e.endpoint, e.clientId, e.clientSecret, token)
if err != nil {
errInfo := fmt.Sprintf("do introspect access token error: %s", err.Error())
ctx.Response().SetBody([]byte(errInfo))
ctx.Response().SetStatus(http.StatusInternalServerError, "Internal Server Error")
return fmt.Errorf(errInfo)
}
}
err = verifyIntrospection(introspectionInfo.Config, e.clientId, e.scopes)
if err != nil {
// 校验失败
errInfo := fmt.Sprintf("verify introspection error: %s", err.Error())
ctx.Response().SetBody([]byte(errInfo))
ctx.Response().SetStatus(http.StatusUnauthorized, "Unauthorized")
return fmt.Errorf(errInfo)
}
err = setAppLabel(ctx, introspectionInfo.Config, e.consumeBy, e.allowAnonymous)
if err != nil {
errInfo := fmt.Sprintf("set app label error: %s", err.Error())
ctx.Response().SetBody([]byte(errInfo))
ctx.Response().SetStatus(http.StatusUnauthorized, "Unauthorized")
return fmt.Errorf(errInfo)
}
if cache != nil {
d, err := json.Marshal(introspectionInfo)
if err == nil {
_, err = cache.SetNX(ctx.Context(), fmt.Sprintf("%s:%s", redisKeyPrefix, token), d, e.ttl).Result()
if err != nil {
errInfo := fmt.Sprintf("set cache error: %s", err.Error())
ctx.Response().SetBody([]byte(errInfo))
ctx.Response().SetStatus(http.StatusInternalServerError, "Internal Server Error")
return fmt.Errorf(errInfo)
}
}
}
if e.hideCredential {
ctx.Proxy().Header().DelHeader(e.tokenName)
}
ctx.Proxy().Header().SetHeader("X-Credential-Scope", introspectionInfo.Config.Scope)
ctx.Proxy().Header().SetHeader("X-Credential-Client-ID", introspectionInfo.Config.ClientId)
ctx.Proxy().Header().SetHeader("X-Credential-Token-Type", "Bearer")
ctx.Proxy().Header().SetHeader("X-Credential-Exp", fmt.Sprintf("%d", introspectionInfo.Config.Exp))
ctx.Proxy().Header().SetHeader("X-Credential-Iat", fmt.Sprintf("%d", introspectionInfo.Config.Iat))
ctx.Proxy().Header().SetHeader("X-Credential-Nbf", fmt.Sprintf("%d", introspectionInfo.Config.Nbf))
ctx.Proxy().Header().SetHeader("X-Credential-Sub", introspectionInfo.Config.Sub)
ctx.Proxy().Header().SetHeader("X-Credential-Aud", introspectionInfo.Config.Aud)
ctx.Proxy().Header().SetHeader("X-Credential-Iss", introspectionInfo.Config.Iss)
ctx.Proxy().Header().SetHeader("X-Credential-Jti", introspectionInfo.Config.Jti)
for _, v := range e.claims {
a, ok := introspectionInfo.Append[v]
if !ok {
continue
}
vv, ok := a.(string)
if !ok {
continue
}
ctx.Proxy().Header().SetHeader(fmt.Sprintf("X-Credential-%s", v), vv)
}
if next != nil {
return next.DoChain(ctx)
}
return nil
}

View File

@@ -0,0 +1,44 @@
package oauth2_introspection
import (
"github.com/eolinker/apinto/drivers"
"github.com/eolinker/apinto/drivers/app/manager"
"github.com/eolinker/eosc"
"github.com/eolinker/eosc/common/bean"
"sync"
)
const (
Name = "oauth2-introspection"
)
var (
ones sync.Once
appManager manager.IManager
)
func Register(register eosc.IExtenderDriverRegister) {
register.RegisterExtenderDriver(Name, NewFactory())
}
func NewFactory() eosc.IExtenderDriverFactory {
ones.Do(func() {
bean.Autowired(&appManager)
})
return drivers.NewFactory[Config](Create)
}
func Create(id, name string, conf *Config, workers map[eosc.RequireId]eosc.IWorker) (eosc.IWorker, error) {
err := Check(conf)
if err != nil {
return nil, err
}
e := &executor{
WorkerBase: drivers.Worker(id, name),
}
err = e.reset(conf)
if err != nil {
return nil, err
}
return e, nil
}

View File

@@ -0,0 +1,155 @@
package oauth2_introspection
import (
"encoding/json"
"fmt"
"github.com/eolinker/eosc"
http_service "github.com/eolinker/eosc/eocontext/http-context"
"io"
"net/http"
"net/url"
"strings"
"time"
)
type IntrospectionResponseBody struct {
Active bool `json:"active"`
ClientId string `json:"client_id"`
Username string `json:"username"`
Scope string `json:"scope"`
Sub string `json:"sub"`
Aud string `json:"aud"`
Iss string `json:"iss"`
Exp int64 `json:"exp"`
Iat int64 `json:"iat"`
Nbf int64 `json:"nbf"`
Jti string `json:"jti"`
}
func setAppLabel(ctx http_service.IHttpContext, t *IntrospectionResponseBody, consumerBy string, allowAnonymous bool) error {
consumer := t.ClientId
switch consumerBy {
case "client_id":
case "username":
consumer = t.Username
default:
return fmt.Errorf("invalid consumer_by")
}
a, has := appManager.GetApp(consumer)
if !has {
if !allowAnonymous {
return fmt.Errorf("consumer(%s) not found", consumer)
}
a = appManager.AnonymousApp()
if a == nil {
return fmt.Errorf("anonymous app not found")
}
ctx.Proxy().Header().SetHeader("X-Consumer-Anonymous", "true")
}
ctx.SetLabel("application_id", a.Id())
ctx.SetLabel("application_name", a.Name())
ctx.Proxy().Header().SetHeader("X-Consumer-ID", a.Id())
ctx.Proxy().Header().SetHeader("X-Consumer-Username", a.Name())
return nil
}
func verifyIntrospection(t *IntrospectionResponseBody, clientId string, scopes map[string]struct{}) error {
if t.Active != true {
return fmt.Errorf("token is not active")
}
if t.ClientId != clientId {
return fmt.Errorf("invalid client_id")
}
now := time.Now()
if t.Exp < now.Unix() {
return fmt.Errorf("token is expired")
}
if t.Iat > now.Unix() {
return fmt.Errorf("token is not yet active")
}
if len(scopes) > 0 {
if _, ok := scopes[t.Scope]; !ok {
return fmt.Errorf("invalid scope")
}
}
return nil
}
func checkActive(t *IntrospectionResponseBody) bool {
if t.Active != true {
return false
}
now := time.Now()
if t.Exp < now.Unix() {
return false
}
if t.Iat > now.Unix() {
return false
}
return true
}
func doIntrospectAccessToken(client *http.Client, endpoint string, clientId string, clientSecret string, token string) (*eosc.Base[IntrospectionResponseBody], error) {
body := url.Values{}
body.Set("token", token)
body.Set("client_id", clientId)
body.Set("client_secret", clientSecret)
req, err := http.NewRequest(http.MethodPost, endpoint, strings.NewReader(body.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(data))
}
t := new(eosc.Base[IntrospectionResponseBody])
err = json.Unmarshal(data, t)
if err != nil {
return nil, err
}
return t, nil
}
func retrieveAccessToken(ctx http_service.IHttpContext, tokenPosition string, tokenName string) string {
token := ""
switch tokenPosition {
case positionHeader:
token = ctx.Request().Header().GetHeader(tokenName)
return strings.TrimPrefix(token, "Bearer ")
case positionQuery:
token = ctx.Request().URI().GetQuery(tokenName)
case positionBody:
if strings.Contains(ctx.Request().ContentType(), "application/x-www-form-urlencoded") || strings.Contains(ctx.Request().ContentType(), "multipart/form-data") {
token = ctx.Request().Body().GetForm(tokenName)
} else if strings.Contains(ctx.Request().ContentType(), "application/json") {
body, _ := ctx.Request().Body().RawBody()
if string(body) != "" {
m := make(map[string]interface{})
err := json.Unmarshal(body, &m)
if err == nil {
if v, ok := m[tokenName]; ok {
token = fmt.Sprintf("%v", v)
}
} else {
return ""
}
}
}
default:
return ""
}
return strings.TrimPrefix(token, "Bearer ")
}