mirror of
https://github.com/xslasd/x-oidc.git
synced 2025-09-27 04:16:00 +08:00
187 lines
5.7 KiB
Go
187 lines
5.7 KiB
Go
package oidc
|
|
|
|
import (
|
|
"context"
|
|
"github.com/xslasd/x-oidc/constant"
|
|
"github.com/xslasd/x-oidc/ecode"
|
|
"github.com/xslasd/x-oidc/model"
|
|
"github.com/xslasd/x-oidc/storage"
|
|
"github.com/xslasd/x-oidc/util"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// TokenResponse carries the values handed back to a client after a token
// request. Each field is tagged omitempty for both the json and schema
// encodings, so zero-valued fields are left out of the encoded form.
type TokenResponse struct {
	// AccessToken is encoded as "access_token".
	AccessToken string `json:"access_token,omitempty" schema:"access_token,omitempty"`
	// TokenType is encoded as "token_type".
	TokenType string `json:"token_type,omitempty" schema:"token_type,omitempty"`
	// RefreshToken is encoded as "refresh_token".
	RefreshToken string `json:"refresh_token,omitempty" schema:"refresh_token,omitempty"`
	// ExpiresIn is encoded as "expires_in".
	ExpiresIn uint64 `json:"expires_in,omitempty" schema:"expires_in,omitempty"`
	// IDToken is encoded as "id_token".
	IDToken string `json:"id_token,omitempty" schema:"id_token,omitempty"`
	// State is encoded as "state".
	State string `json:"state,omitempty" schema:"state,omitempty"`
}
|
|
|
|
func (o *OpenIDProvider) CreateAccessTokenAndIDToken(ctx context.Context, req *storage.AuthRequest, client storage.IClient, fn func() (*storage.TokenModel, error)) (*model.AccessTokenRes, error) {
|
|
accessToken, newRefreshToken, validity, err := o.CreateAccessToken(ctx, req, client, fn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
authTime := time.Now().UTC()
|
|
var idToken string
|
|
|
|
if req.UserID != "" {
|
|
idToken, err = o.CreateIDToken(ctx, req, client, authTime, func(claims *model.IDTokenClaims) error {
|
|
atHash, err := o.cfg.Crypto.ClaimHash(accessToken)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
claims.AccessTokenHash = atHash
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return &model.AccessTokenRes{
|
|
AccessToken: accessToken,
|
|
TokenType: client.AccessTokenTransferType(),
|
|
RefreshToken: newRefreshToken,
|
|
ExpiresIn: uint64(validity),
|
|
IDToken: idToken,
|
|
State: req.State,
|
|
}, nil
|
|
}
|
|
|
|
func (o *OpenIDProvider) CreateAccessToken(ctx context.Context, req *storage.AuthRequest, client storage.IClient, fn func() (*storage.TokenModel, error)) (accessToken, refreshToken string, validity time.Duration, err error) {
|
|
tokenModel, err := fn()
|
|
if err != nil {
|
|
return "", "", 0, err
|
|
}
|
|
validity = tokenModel.AccessTokenExpiration.Add(client.ClockSkew()).Sub(time.Now().UTC())
|
|
if client.AccessTokenTransferType() == constant.AccessTokenTransferTypeJWT {
|
|
accessToken, err = o.CreateJWTAccessToken(ctx, req, client, func(claims *model.AccessTokenClaims) error {
|
|
claims.TokenClaims.JWTID = tokenModel.TokenID
|
|
claims.TokenClaims.Expiration = tokenModel.AccessTokenExpiration.Unix()
|
|
return err
|
|
})
|
|
return accessToken, tokenModel.RefreshToken, validity, err
|
|
}
|
|
accessToken, err = o.cfg.Crypto.Encrypt(tokenModel.TokenID + ":" + req.UserID)
|
|
return accessToken, tokenModel.RefreshToken, validity, err
|
|
}
|
|
|
|
func (o *OpenIDProvider) CreateJWTAccessToken(ctx context.Context, req *storage.AuthRequest, client storage.IClient, fn func(claims *model.AccessTokenClaims) error) (string, error) {
|
|
now := time.Now().UTC().Add(-client.ClockSkew())
|
|
scopes := client.RestrictAdditionalAccessTokenScopes()(req.Scopes)
|
|
claims := &model.AccessTokenClaims{
|
|
TokenClaims: model.TokenClaims{
|
|
Issuer: o.cfg.Issuer,
|
|
Subject: req.UserID,
|
|
Audience: req.Audience,
|
|
IssuedAt: now.Unix(),
|
|
NotBefore: now.Unix(),
|
|
},
|
|
Scopes: strings.Join(scopes, " "),
|
|
}
|
|
err := fn(claims)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
accessToken, err := o.cfg.Crypto.GenerateJWT(claims)
|
|
return accessToken, err
|
|
}
|
|
func (o *OpenIDProvider) CreateIDToken(ctx context.Context, req *storage.AuthRequest, client storage.IClient, authTime time.Time, fn func(claims *model.IDTokenClaims) error) (string, error) {
|
|
expiration := authTime.Add(client.IDTokenLifetime())
|
|
now := time.Now().UTC().Add(-client.ClockSkew())
|
|
scopes := client.RestrictAdditionalIdTokenScopes()(req.Scopes)
|
|
if !client.IDTokenUserinfoClaimsAssertion() {
|
|
scopes = util.RemoveUserinfoScopes(scopes)
|
|
}
|
|
claims := &model.IDTokenClaims{
|
|
TokenClaims: model.TokenClaims{
|
|
Issuer: o.cfg.Issuer,
|
|
Subject: req.UserID,
|
|
Audience: req.Audience,
|
|
Expiration: expiration.Unix(),
|
|
IssuedAt: now.Unix(),
|
|
NotBefore: now.Unix(),
|
|
Nonce: req.Nonce,
|
|
ClientID: client.GetClientID(),
|
|
},
|
|
AuthTime: authTime.Unix(),
|
|
}
|
|
err := fn(claims)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if len(scopes) > 0 {
|
|
userInfo, err := o.cfg.Storage.SetUserinfoFromScopes(ctx, *req, client, scopes)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
claims.SetUserInfo(userInfo)
|
|
}
|
|
idToken, err := o.cfg.Crypto.GenerateJWT(claims)
|
|
return idToken, err
|
|
}
|
|
|
|
func (o *OpenIDProvider) VerifyAccessToken(ctx context.Context, tokenStr string) (*model.AccessTokenClaims, error) {
|
|
parts := strings.Split(tokenStr, ".")
|
|
switch len(parts) {
|
|
case 1:
|
|
tokenIDSubject, err := o.cfg.Crypto.Decrypt(tokenStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
splitToken := strings.Split(tokenIDSubject, ":")
|
|
if len(splitToken) != 2 {
|
|
return nil, ecode.AccessTokenInvalid
|
|
}
|
|
return &model.AccessTokenClaims{
|
|
TokenClaims: model.TokenClaims{JWTID: splitToken[0], Subject: splitToken[1]},
|
|
}, nil
|
|
case 3:
|
|
res := new(model.AccessTokenClaims)
|
|
err := o.cfg.Crypto.ParseJWT(tokenStr, res)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = res.CheckIssuer(o.cfg.Issuer)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = res.CheckExpiration()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = res.CheckIssuedAt()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return res, nil
|
|
default:
|
|
return nil, ecode.AccessTokenInvalid
|
|
}
|
|
}
|
|
|
|
func (o *OpenIDProvider) VerifyIDToken(ctx context.Context, tokenStr string) (*model.IDTokenClaims, error) {
|
|
res := new(model.IDTokenClaims)
|
|
err := o.cfg.Crypto.ParseJWT(tokenStr, res)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = res.CheckIssuer(o.cfg.Issuer)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = res.CheckExpiration()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = res.CheckIssuedAt()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return res, nil
|
|
}
|