mirror of
https://github.com/weloe/token-go.git
synced 2025-10-06 16:07:18 +08:00
feat: support SSO
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# Token-Go
|
||||
|
||||
This library focuses on solving login authentication problems, such as: login, multi-account login, shared token, logout, kickout ...
|
||||
This library focuses on solving login authentication problems, such as: login, multi-account login, shared token, logout, kickout, banned, SSO ...
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -189,6 +189,9 @@ func CheckAuth(w http.ResponseWriter, req *http.Request) {
|
||||
fmt.Fprintf(w, "you have authorization")
|
||||
}
|
||||
```
|
||||
## SSO
|
||||
|
||||
|
||||
|
||||
## Api
|
||||
|
||||
|
30
config/sign.go
Normal file
30
config/sign.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package config
|
||||
|
||||
// SignConfig sign config
|
||||
type SignConfig struct {
|
||||
SecretKey string
|
||||
TimeStampDisparity int64
|
||||
IsCheckNonce bool
|
||||
}
|
||||
|
||||
func NewSignConfig(options *SignOptions) (*SignConfig, error) {
|
||||
if options == nil {
|
||||
options = &SignOptions{}
|
||||
}
|
||||
if options.TimeStampDisparity == 0 {
|
||||
options.TimeStampDisparity = 1000 * 60 * 15
|
||||
}
|
||||
return &SignConfig{
|
||||
SecretKey: options.SecretKey,
|
||||
TimeStampDisparity: options.TimeStampDisparity,
|
||||
IsCheckNonce: options.IsCheckNonce,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SignConfig) GetSaveNonceExpire() int64 {
|
||||
if s.TimeStampDisparity >= 0 {
|
||||
return s.TimeStampDisparity / 1000
|
||||
} else {
|
||||
return 60 * 60 * 24
|
||||
}
|
||||
}
|
168
config/sso.go
Normal file
168
config/sso.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/weloe/token-go/ctx"
|
||||
model2 "github.com/weloe/token-go/model"
|
||||
"github.com/weloe/token-go/util"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func DefaultSsoConfig(serverUrl string, notLoginView func() interface{},
|
||||
doLoginHandle func(name string, pwd string, ctx ctx.Context) (interface{}, error),
|
||||
ticketResultHandle func(o1 string, s string) (interface{}, error),
|
||||
sendHttp func(url string) (string, error)) *SsoConfig {
|
||||
if notLoginView == nil {
|
||||
notLoginView = func() interface{} {
|
||||
return "not logged in to the SSO-Server"
|
||||
}
|
||||
}
|
||||
return &SsoConfig{
|
||||
Mode: "",
|
||||
TicketTimeout: 60 * 5,
|
||||
AllowUrl: "*",
|
||||
IsSlo: true,
|
||||
IsHttp: false,
|
||||
Client: "",
|
||||
AuthUrl: "/sso/auth",
|
||||
CheckTicketUrl: "/sso/checkTicket",
|
||||
GetDataUrl: "/sso/getData",
|
||||
UserInfoUrl: "/sso/userInfo",
|
||||
SloUrl: "/sso/signOut",
|
||||
SsoLogoutCall: "",
|
||||
ServerUrl: serverUrl,
|
||||
NotLoginView: notLoginView,
|
||||
DoLoginHandle: doLoginHandle,
|
||||
TicketResultHandle: ticketResultHandle,
|
||||
SendHttp: sendHttp,
|
||||
}
|
||||
}
|
||||
|
||||
func NewSsoConfig(options *SsoOptions) (*SsoConfig, error) {
|
||||
if options == nil {
|
||||
options = &SsoOptions{}
|
||||
}
|
||||
if options.TicketTimeout == 0 {
|
||||
options.TicketTimeout = 60 * 5
|
||||
}
|
||||
if options.AllowUrl == "" {
|
||||
options.AllowUrl = "*"
|
||||
}
|
||||
if options.AuthUrl == "" {
|
||||
options.AuthUrl = "/sso/auth"
|
||||
}
|
||||
if options.CheckTicketUrl == "" {
|
||||
options.CheckTicketUrl = "/sso/checkTicket"
|
||||
}
|
||||
if options.GetDataUrl == "" {
|
||||
options.GetDataUrl = "/sso/getData"
|
||||
}
|
||||
if options.UserInfoUrl == "" {
|
||||
options.UserInfoUrl = "/sso/userInfo"
|
||||
}
|
||||
if options.SloUrl == "" {
|
||||
options.SloUrl = "/sso/signout"
|
||||
}
|
||||
|
||||
if options.NotLoginView == nil {
|
||||
options.NotLoginView = func() interface{} {
|
||||
return "not logged in to the SSO-Server"
|
||||
}
|
||||
}
|
||||
if options.DoLoginHandle == nil {
|
||||
options.DoLoginHandle = func(name string, pwd string, ctx ctx.Context) (interface{}, error) {
|
||||
return model2.Error(), errors.New("SsoConfig.DoLoginHandle is nil")
|
||||
}
|
||||
}
|
||||
if options.IsHttp && options.SendHttp == nil {
|
||||
return nil, errors.New("please config SSO SentHttp")
|
||||
}
|
||||
|
||||
return &SsoConfig{
|
||||
Mode: options.Mode,
|
||||
TicketTimeout: options.TicketTimeout,
|
||||
AllowUrl: options.AllowUrl,
|
||||
IsSlo: options.IsSlo,
|
||||
IsHttp: options.IsHttp,
|
||||
Client: options.Client,
|
||||
AuthUrl: options.AuthUrl,
|
||||
CheckTicketUrl: options.CheckTicketUrl,
|
||||
GetDataUrl: options.GetDataUrl,
|
||||
UserInfoUrl: options.UserInfoUrl,
|
||||
SloUrl: options.SloUrl,
|
||||
SsoLogoutCall: options.SsoLogoutCall,
|
||||
ServerUrl: options.ServerUrl,
|
||||
NotLoginView: options.NotLoginView,
|
||||
DoLoginHandle: options.DoLoginHandle,
|
||||
TicketResultHandle: options.TicketResultHandle,
|
||||
SendHttp: options.SendHttp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type SsoConfig struct {
|
||||
|
||||
// Mode sso mode
|
||||
Mode string
|
||||
// TicketTimeout ticket timeout
|
||||
TicketTimeout int64
|
||||
// AllowUrl All allowed authorization callback addresses, separated by ','
|
||||
AllowUrl string
|
||||
IsSlo bool
|
||||
IsHttp bool
|
||||
|
||||
// SSO-Client current client name
|
||||
Client string
|
||||
// SSO-Server auth url
|
||||
AuthUrl string
|
||||
// SSO-Server check ticket url
|
||||
CheckTicketUrl string
|
||||
GetDataUrl string
|
||||
UserInfoUrl string
|
||||
SloUrl string
|
||||
SsoLogoutCall string
|
||||
|
||||
ServerUrl string
|
||||
|
||||
/**
|
||||
sso callback func
|
||||
*/
|
||||
// NotLoginView
|
||||
NotLoginView func() interface{}
|
||||
|
||||
// DoLoginHandle login func
|
||||
DoLoginHandle func(name string, pwd string, ctx ctx.Context) (interface{}, error)
|
||||
|
||||
// TicketResultHandle called each time the result of the validation ticket is obtained from the SSO-Server
|
||||
TicketResultHandle func(loginId string, back string) (interface{}, error)
|
||||
|
||||
// SendHttp sent http
|
||||
SendHttp func(url string) (string, error)
|
||||
}
|
||||
|
||||
// SpliceAuthUrl return Server-side single sign-on authorization address
|
||||
func (c *SsoConfig) SpliceAuthUrl() string {
|
||||
return util.SpliceUrl(c.ServerUrl, c.AuthUrl)
|
||||
}
|
||||
|
||||
// SpliceCheckTicketUrl return ticket verification address on the server side
|
||||
func (c *SsoConfig) SpliceCheckTicketUrl() string {
|
||||
return util.SpliceUrl(c.ServerUrl, c.CheckTicketUrl)
|
||||
}
|
||||
|
||||
func (c *SsoConfig) SpliceGetDataUrl() string {
|
||||
return util.SpliceUrl(c.ServerUrl, c.GetDataUrl)
|
||||
}
|
||||
|
||||
func (c *SsoConfig) SpliceUserInfoUrl() string {
|
||||
return util.SpliceUrl(c.ServerUrl, c.UserInfoUrl)
|
||||
}
|
||||
|
||||
func (c *SsoConfig) SpliceSloUrl() string {
|
||||
return util.SpliceUrl(c.ServerUrl, c.SloUrl)
|
||||
}
|
||||
|
||||
// SetAllow set allow callback url
|
||||
func (c *SsoConfig) SetAllow(urls ...string) *SsoConfig {
|
||||
c.AllowUrl = strings.Join(urls, ",")
|
||||
return c
|
||||
}
|
52
config/sso_options.go
Normal file
52
config/sso_options.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package config
|
||||
|
||||
import "github.com/weloe/token-go/ctx"
|
||||
|
||||
// SsoOptions new SsoConfig options.
|
||||
type SsoOptions struct {
|
||||
CookieDomain string
|
||||
|
||||
// Mode sso mode
|
||||
Mode string
|
||||
// TicketTimeout ticket timeout
|
||||
TicketTimeout int64
|
||||
// AllowUrl All allowed authorization callback addresses, separated by ','
|
||||
AllowUrl string
|
||||
IsSlo bool
|
||||
IsHttp bool
|
||||
|
||||
// SSO-Client current client name
|
||||
Client string
|
||||
// SSO-Server auth url
|
||||
AuthUrl string
|
||||
// SSO-Server check ticket url
|
||||
CheckTicketUrl string
|
||||
GetDataUrl string
|
||||
UserInfoUrl string
|
||||
SloUrl string
|
||||
SsoLogoutCall string
|
||||
|
||||
ServerUrl string
|
||||
|
||||
/**
|
||||
sso callback func
|
||||
*/
|
||||
// NotLoginView
|
||||
NotLoginView func() interface{}
|
||||
|
||||
// DoLoginHandle login func
|
||||
DoLoginHandle func(name string, pwd string, ctx ctx.Context) (interface{}, error)
|
||||
|
||||
// TicketResultHandle called each time the result of the validation ticket is obtained from the SSO-Server
|
||||
TicketResultHandle func(loginId string, back string) (interface{}, error)
|
||||
|
||||
// SendHttp sent http
|
||||
SendHttp func(url string) (string, error)
|
||||
}
|
||||
|
||||
// SignOptions SignConfig options
|
||||
type SignOptions struct {
|
||||
SecretKey string
|
||||
TimeStampDisparity int64
|
||||
IsCheckNonce bool
|
||||
}
|
11
constant/sso.go
Normal file
11
constant/sso.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
SLO_CALLBACK_SET_KEY = "SLO_CALLBACK_SET_KEY_"
|
||||
|
||||
SELF = "self"
|
||||
|
||||
MODE_SIMPLE = "simple"
|
||||
|
||||
NOT_HANDLE = "{\"msg\": \"not handle\"}"
|
||||
)
|
39
model/http_result.go
Normal file
39
model/http_result.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package model
|
||||
|
||||
const (
|
||||
SUCCESS = 1
|
||||
ERROR = 0
|
||||
)
|
||||
|
||||
// Result wrap the http request result.
|
||||
type Result struct {
|
||||
Code int
|
||||
Msg string
|
||||
Data interface{}
|
||||
}
|
||||
|
||||
func Ok() *Result {
|
||||
return &Result{
|
||||
Code: SUCCESS,
|
||||
Msg: "success",
|
||||
Data: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func Error() *Result {
|
||||
return &Result{
|
||||
Code: -1,
|
||||
Msg: "error",
|
||||
Data: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Result) SetData(data interface{}) *Result {
|
||||
r.Data = data
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Result) SetMsg(msg string) *Result {
|
||||
r.Msg = msg
|
||||
return r
|
||||
}
|
34
sso/api.go
Normal file
34
sso/api.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package sso
|
||||
|
||||
// ApiName sso api name, used to dispatcher request.
|
||||
type ApiName struct {
|
||||
// sso-server auth url
|
||||
SsoAuth string
|
||||
// sso-server rest api login url
|
||||
SsoDoLogin string
|
||||
// sso-server check ticket url
|
||||
SsoCheckTicket string
|
||||
// sso-server get user info url
|
||||
SsoUserInfo string
|
||||
// sso-server single logout url
|
||||
SsoSignout string
|
||||
// sso-client login url
|
||||
SsoLogin string
|
||||
// sso-client single logout url
|
||||
SsoLogout string
|
||||
// sso-client logout callback url
|
||||
SsoLogoutCall string
|
||||
}
|
||||
|
||||
func DefaultApiName() *ApiName {
|
||||
return &ApiName{
|
||||
SsoAuth: "/sso/auth",
|
||||
SsoDoLogin: "/sso/doLogin",
|
||||
SsoCheckTicket: "/sso/checkTicket",
|
||||
SsoUserInfo: "/sso/userInfo",
|
||||
SsoSignout: "/sso/signout",
|
||||
SsoLogin: "/sso/login",
|
||||
SsoLogout: "/sso/logout",
|
||||
SsoLogoutCall: "/sso/logoutCall",
|
||||
}
|
||||
}
|
39
sso/param.go
Normal file
39
sso/param.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package sso
|
||||
|
||||
// ParamName http request param name.
|
||||
type ParamName struct {
|
||||
Redirect string
|
||||
Ticket string
|
||||
Back string
|
||||
Mode string
|
||||
LoginId string
|
||||
Client string
|
||||
SsoLogoutCall string
|
||||
Name string
|
||||
Pwd string
|
||||
|
||||
//=== sign param
|
||||
|
||||
TimeStamp string
|
||||
Nonce string
|
||||
Sign string
|
||||
SecretKet string
|
||||
}
|
||||
|
||||
func DefaultParamName() *ParamName {
|
||||
return &ParamName{
|
||||
Redirect: "redirect",
|
||||
Ticket: "ticket",
|
||||
Back: "back",
|
||||
Mode: "mode",
|
||||
LoginId: "loginId",
|
||||
Client: "client",
|
||||
SsoLogoutCall: "ssoLogoutCall",
|
||||
Name: "name",
|
||||
Pwd: "pwd",
|
||||
TimeStamp: "timestamp",
|
||||
Nonce: "nonce",
|
||||
Sign: "sign",
|
||||
SecretKet: "key",
|
||||
}
|
||||
}
|
108
sso/sso.go
Normal file
108
sso/sso.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package sso
|
||||
|
||||
import (
|
||||
"errors"
|
||||
tokenGo "github.com/weloe/token-go"
|
||||
"github.com/weloe/token-go/config"
|
||||
"github.com/weloe/token-go/constant"
|
||||
"github.com/weloe/token-go/ctx"
|
||||
"github.com/weloe/token-go/model"
|
||||
)
|
||||
|
||||
// Options construct options
|
||||
type Options struct {
|
||||
SsoOptions *config.SsoOptions
|
||||
SignOptions *config.SignOptions
|
||||
Enforcer tokenGo.IEnforcer
|
||||
}
|
||||
|
||||
type SsoEnforcer struct {
|
||||
apiName *ApiName
|
||||
paramName *ParamName
|
||||
config *config.SsoConfig
|
||||
signConfig *config.SignConfig
|
||||
enforcer tokenGo.IEnforcer
|
||||
}
|
||||
|
||||
// NewSsoEnforcer create sso enforcer.
|
||||
// If the available field in the parameter is empty, use the default value,
|
||||
// if the required field is empty, return nil and error.
|
||||
func NewSsoEnforcer(options *Options) (*SsoEnforcer, error) {
|
||||
if options.Enforcer == nil {
|
||||
return nil, errors.New("Options.Enforcer can not be nil")
|
||||
}
|
||||
if options.SsoOptions.CookieDomain != "" {
|
||||
options.Enforcer.GetTokenConfig().CookieConfig.Domain = options.SsoOptions.CookieDomain
|
||||
}
|
||||
|
||||
ssoConfig, err := config.NewSsoConfig(options.SsoOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signConfig, err := config.NewSignConfig(options.SignOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &SsoEnforcer{
|
||||
apiName: DefaultApiName(),
|
||||
paramName: DefaultParamName(),
|
||||
config: ssoConfig,
|
||||
signConfig: signConfig,
|
||||
enforcer: options.Enforcer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SsoEnforcer) SetApi(apiName *ApiName) {
|
||||
s.apiName = apiName
|
||||
}
|
||||
|
||||
func (s *SsoEnforcer) GetApi() *ApiName {
|
||||
return s.apiName
|
||||
}
|
||||
|
||||
func (s *SsoEnforcer) SetParamName(paramName *ParamName) {
|
||||
s.paramName = paramName
|
||||
}
|
||||
|
||||
func (s *SsoEnforcer) GetParamName() *ParamName {
|
||||
return s.paramName
|
||||
}
|
||||
|
||||
func (s *SsoEnforcer) SetSsoConfig(config *config.SsoConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
func (s *SsoEnforcer) GetSsoConfig() *config.SsoConfig {
|
||||
return s.config
|
||||
}
|
||||
|
||||
func (s *SsoEnforcer) SetSignConfig(config *config.SignConfig) {
|
||||
s.signConfig = config
|
||||
}
|
||||
|
||||
func (s *SsoEnforcer) GetSignConfig() *config.SignConfig {
|
||||
return s.signConfig
|
||||
}
|
||||
|
||||
// ssoLogoutBack single-logout callback, redirect to ParamName.Back url.
|
||||
// If http request has back param and value is SELF, redirect to previous page,
|
||||
// if http request back param and value is validated url, redirect to url,
|
||||
// else return model.Ok() directly.
|
||||
func (s *SsoEnforcer) ssoLogoutBack(ctx ctx.Context) (interface{}, error) {
|
||||
paramName := s.paramName
|
||||
request := ctx.Request()
|
||||
response := ctx.Response()
|
||||
back := request.Query(paramName.Back)
|
||||
if back != "" {
|
||||
if back == constant.SELF {
|
||||
return "<script>if(document.referrer != location.href){ location.replace(document.referrer || '/'); }</script>", nil
|
||||
}
|
||||
response.Redirect(back)
|
||||
return nil, nil
|
||||
} else {
|
||||
// back is nil
|
||||
return model.Ok().SetMsg("back is nil, not redirect"), nil
|
||||
}
|
||||
|
||||
}
|
206
sso/sso_client_api.go
Normal file
206
sso/sso_client_api.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package sso
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/weloe/token-go/ctx"
|
||||
"github.com/weloe/token-go/model"
|
||||
"strings"
|
||||
)
|
||||
|
||||
/**
|
||||
=========processor SSO-Client api
|
||||
*/
|
||||
|
||||
// SsoClientLogin SSO-Client login.
|
||||
func (s *SsoEnforcer) SsoClientLogin(ctx ctx.Context) (interface{}, error) {
|
||||
request := ctx.Request()
|
||||
response := ctx.Response()
|
||||
paramName := s.paramName
|
||||
apiName := s.apiName
|
||||
|
||||
// get back value and redirect
|
||||
back := request.Query(paramName.Back)
|
||||
if back == "" {
|
||||
back = "/"
|
||||
}
|
||||
isLogin, err := s.enforcer.IsLogin(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// if the current client is already logged in, there is no need to redirect to the SSO-Server
|
||||
if isLogin {
|
||||
response.Redirect(back)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// if isLogin == false, attempt to get ticket
|
||||
ticket := request.Query(paramName.Ticket)
|
||||
|
||||
// if ticket == "", redirect to SSO-Server, the default path is /sso/auth
|
||||
if ticket == "" {
|
||||
serverAuthUrl, err := s.buildServerAuthUrl(request.UrlNoQuery(), back)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response.Redirect(serverAuthUrl)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// else ticket != "", need to log in by ticket
|
||||
var loginId string
|
||||
|
||||
// get current path
|
||||
ssoLoginUrl := apiName.SsoLogin
|
||||
if s.config.IsHttp {
|
||||
// use http request to SSO-Server to check ticket in SSO-Server
|
||||
var ssoLogoutCall string
|
||||
if s.config.IsSlo {
|
||||
// get logout callback url
|
||||
if s.config.SsoLogoutCall != "" {
|
||||
ssoLogoutCall = s.config.SsoLogoutCall
|
||||
} else if ssoLoginUrl != "" {
|
||||
ssoLogoutCall = strings.ReplaceAll(request.UrlNoQuery(), ssoLoginUrl, apiName.SsoLogoutCall)
|
||||
}
|
||||
}
|
||||
// send http to check
|
||||
checkTicketUrl, err := s.buildCheckTicketUrl(ticket, ssoLogoutCall)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// send http request
|
||||
resp, err := s.request(checkTicketUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.Code == model.ERROR {
|
||||
return nil, errors.New("request failed: " + resp.Msg)
|
||||
}
|
||||
loginId = resp.Data.(string)
|
||||
} else {
|
||||
// use adapter to check ticket
|
||||
loginId, err = s.CheckTicket(ticket)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// if set callback TicketResultHandle
|
||||
if s.config.TicketResultHandle != nil {
|
||||
return s.config.TicketResultHandle(loginId, back)
|
||||
}
|
||||
|
||||
// if loginId == "", return error
|
||||
if loginId == "" {
|
||||
return nil, errors.New("invalid ticket: " + ticket)
|
||||
}
|
||||
// login
|
||||
_, err = s.enforcer.Login(loginId, ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// redirect to back
|
||||
response.Redirect(back)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// SsoClientLogout SSO-Client single-logout.
|
||||
func (s *SsoEnforcer) SsoClientLogout(ctx ctx.Context) (interface{}, error) {
|
||||
ssoConfig := s.config
|
||||
|
||||
// enable single-logout and isHttp == false
|
||||
if ssoConfig.IsSlo && !ssoConfig.IsHttp {
|
||||
isLogin, err := s.enforcer.IsLogin(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// check if you are logged in
|
||||
if isLogin {
|
||||
var id string
|
||||
id, err = s.enforcer.GetLoginId(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// client logout
|
||||
err = s.enforcer.LogoutById(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// callback
|
||||
return s.ssoLogoutBack(ctx)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// enable single-logout and isHttp
|
||||
if ssoConfig.IsSlo && ssoConfig.IsHttp {
|
||||
isLogin, err := s.enforcer.IsLogin(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !isLogin {
|
||||
return s.ssoLogoutBack(ctx)
|
||||
}
|
||||
|
||||
// get id
|
||||
var id string
|
||||
id, err = s.enforcer.GetLoginId(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// use SSO-Server single-logout api to logout
|
||||
sloUrl, err := s.buildSloUrl(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := s.request(sloUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// check response data
|
||||
if res.Code == model.SUCCESS {
|
||||
login, _ := s.enforcer.IsLogin(ctx)
|
||||
if login {
|
||||
err := s.enforcer.Logout(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return s.ssoLogoutBack(ctx)
|
||||
} else {
|
||||
return nil, errors.New("request failed: " + res.Msg)
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("not handle")
|
||||
}
|
||||
|
||||
// SsoClientLogoutCall client logout callback.
|
||||
func (s *SsoEnforcer) SsoClientLogoutCall(ctx ctx.Context) (interface{}, error) {
|
||||
request := ctx.Request()
|
||||
loginId := request.Query(s.paramName.LoginId)
|
||||
if loginId == "" {
|
||||
return nil, errors.New("request param must include loginId")
|
||||
}
|
||||
// check request param
|
||||
err := s.checkRequest(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// logout
|
||||
err = s.enforcer.Logout(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return model.Ok().SetMsg("logout callback success"), nil
|
||||
}
|
||||
|
||||
// GetData client build url to sent http to get data from SSO-Server.
|
||||
func (s *SsoEnforcer) GetData(paramMap map[string]string) (interface{}, error) {
|
||||
finalUrl, err := s.buildGetDataUrl(paramMap)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := s.config.SendHttp(finalUrl)
|
||||
return res, err
|
||||
}
|
64
sso/sso_dispatcher_api.go
Normal file
64
sso/sso_dispatcher_api.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package sso
|
||||
|
||||
import (
|
||||
"github.com/weloe/token-go/ctx"
|
||||
"github.com/weloe/token-go/model"
|
||||
)
|
||||
|
||||
/**
|
||||
=========dispatcher api
|
||||
*/
|
||||
|
||||
// ServerDisPatcher dispatcher SSO-Server api, returns model.Result or string.
|
||||
func (s *SsoEnforcer) ServerDisPatcher(ctx ctx.Context) interface{} {
|
||||
request := ctx.Request()
|
||||
apiName := s.apiName
|
||||
path := request.Path()
|
||||
var res interface{}
|
||||
var err error
|
||||
if path == apiName.SsoAuth {
|
||||
res, err = s.SsoAuth(ctx)
|
||||
} else if path == apiName.SsoDoLogin {
|
||||
res, err = s.SsoDoLogin(ctx)
|
||||
} else if path == apiName.SsoCheckTicket && s.config.IsHttp {
|
||||
res, err = s.SsoCheckTicket(ctx)
|
||||
} else if path == apiName.SsoSignout {
|
||||
res, err = s.SsoSignOut(ctx)
|
||||
} else {
|
||||
return model.Error().SetMsg("not handle")
|
||||
}
|
||||
if err != nil {
|
||||
return model.Error().SetMsg(err.Error())
|
||||
}
|
||||
if res == nil {
|
||||
return model.Ok()
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// ClientDispatcher dispatcher Client api, returns model.Result or string.
|
||||
func (s *SsoEnforcer) ClientDispatcher(ctx ctx.Context) interface{} {
|
||||
request := ctx.Request()
|
||||
apiName := s.apiName
|
||||
path := request.Path()
|
||||
var res interface{}
|
||||
var err error
|
||||
|
||||
if path == apiName.SsoLogin {
|
||||
res, err = s.SsoClientLogin(ctx)
|
||||
} else if path == apiName.SsoLogout {
|
||||
res, err = s.SsoClientLogout(ctx)
|
||||
} else if path == apiName.SsoLogoutCall && s.config.IsSlo && s.config.IsHttp {
|
||||
res, err = s.SsoClientLogoutCall(ctx)
|
||||
} else {
|
||||
return model.Error().SetMsg("not handle")
|
||||
}
|
||||
if err != nil {
|
||||
return model.Error().SetMsg(err.Error())
|
||||
}
|
||||
if res == nil {
|
||||
return model.Ok()
|
||||
}
|
||||
return res
|
||||
}
|
494
sso/sso_internal_api.go
Normal file
494
sso/sso_internal_api.go
Normal file
@@ -0,0 +1,494 @@
|
||||
package sso
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/weloe/token-go/constant"
|
||||
"github.com/weloe/token-go/ctx"
|
||||
"github.com/weloe/token-go/model"
|
||||
"github.com/weloe/token-go/util"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
/**
|
||||
=========internal api
|
||||
*/
|
||||
|
||||
// checkRequest check request param timestamp,nonce,sign.
|
||||
func (s *SsoEnforcer) checkRequest(request ctx.Request) error {
|
||||
timestamp := request.Query(s.paramName.TimeStamp)
|
||||
nonce := request.Query(s.paramName.Nonce)
|
||||
sign := request.Query(s.paramName.Sign)
|
||||
err := s.checkTimeStamp(timestamp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.signConfig.IsCheckNonce {
|
||||
err = s.checkNonce(nonce)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = s.checkSign(timestamp, nonce, sign)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateTicket create ticket by account-id.
|
||||
func (s *SsoEnforcer) CreateTicket(loginId string, client string) (string, error) {
|
||||
// create random string ticket
|
||||
ticket, err := util.GenerateRandomString64()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// save ticket-id+client
|
||||
err = s.saveTicket(ticket, loginId, client)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// save id-ticket
|
||||
err = s.saveTicketIndex(ticket, loginId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return ticket, nil
|
||||
}
|
||||
|
||||
// GetLoginId get loginId by ticket.
|
||||
func (s *SsoEnforcer) GetLoginId(ticket string) string {
|
||||
if ticket == "" {
|
||||
return ""
|
||||
}
|
||||
loginId := s.enforcer.GetAdapter().GetStr(s.spliceTicketSaveKey(ticket))
|
||||
if loginId != "" && strings.Contains(loginId, ",") {
|
||||
split := strings.Split(loginId, ",")
|
||||
loginId = split[0]
|
||||
}
|
||||
return loginId
|
||||
}
|
||||
|
||||
// GetTicket get ticket by loginId.
|
||||
func (s *SsoEnforcer) GetTicket(loginId string) string {
|
||||
if loginId == "" {
|
||||
return ""
|
||||
}
|
||||
return s.enforcer.GetAdapter().GetStr(s.spliceTicketIndexKey(loginId))
|
||||
}
|
||||
|
||||
// CheckTicket use config.Client to check ticket,return loginId.
|
||||
func (s *SsoEnforcer) CheckTicket(ticket string) (string, error) {
|
||||
return s.CheckTicketByClient(ticket, s.config.Client)
|
||||
}
|
||||
|
||||
// CheckTicketByClient check ticket by pointing client,return loginId.
|
||||
func (s *SsoEnforcer) CheckTicketByClient(ticket string, client string) (string, error) {
|
||||
id := s.enforcer.GetAdapter().GetStr(s.spliceTicketSaveKey(ticket))
|
||||
if id == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var ticketClient string
|
||||
if strings.Contains(id, ",") {
|
||||
split := strings.Split(id, ",")
|
||||
id = split[0]
|
||||
ticketClient = split[1]
|
||||
}
|
||||
|
||||
if client != "" && client != ticketClient {
|
||||
return "", fmt.Errorf("the ticket does not belong to the client, client: %v, ticket: %v", client, ticket)
|
||||
}
|
||||
err := s.deleteTicket(ticket)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
err = s.deleteTicketIndex(id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// CheckRedirectUrl check redirectUrl.
|
||||
func (s *SsoEnforcer) CheckRedirectUrl(url string) error {
|
||||
if !util.IsValidUrl(url) {
|
||||
return fmt.Errorf("invalid redirect url: %v", url)
|
||||
}
|
||||
index := strings.Index(url, "?")
|
||||
if index != -1 {
|
||||
url = url[0:index]
|
||||
}
|
||||
allowUrls := strings.Split(s.GetAllowUrl(), ",")
|
||||
|
||||
if !util.HasUrl(allowUrls, url) {
|
||||
return fmt.Errorf("illegal redirect url: %v", url)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterSloCallbackUrl register the URL of the single logout callback for the account id.
|
||||
func (s *SsoEnforcer) RegisterSloCallbackUrl(loginId string, sloCallbackUrl string) error {
|
||||
if loginId == "" || sloCallbackUrl == "" {
|
||||
return nil
|
||||
}
|
||||
session := s.enforcer.GetSession(loginId)
|
||||
// splice session id
|
||||
sessionId := s.enforcer.GetTokenConfig().TokenName + ":" + s.enforcer.GetType() + ":session:" + loginId
|
||||
if session == nil {
|
||||
session = model.NewSession(sessionId, "account-session", loginId)
|
||||
}
|
||||
value := session.Get(constant.SLO_CALLBACK_SET_KEY)
|
||||
|
||||
var v []string
|
||||
if value != nil {
|
||||
sv, ok := value.([]string)
|
||||
v = sv
|
||||
if !ok {
|
||||
return errors.New("session SLO_CALLBACK_SET_KEY_ data convert into []string failed")
|
||||
}
|
||||
}
|
||||
v = util.AppendStr(v, sloCallbackUrl)
|
||||
|
||||
session.Set(sessionId, v)
|
||||
// update session
|
||||
err := s.enforcer.UpdateSession(loginId, session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ssoSignOutById single sign-out of the specified account.
|
||||
// Use loginId to get single sign-out urls from session.
|
||||
func (s *SsoEnforcer) ssoSignOutById(loginId string) error {
|
||||
// if loginId is not logged, return error
|
||||
session := s.enforcer.GetSession(loginId)
|
||||
if session == nil {
|
||||
return errors.New("this loginId is not logged in")
|
||||
}
|
||||
value := session.Get(constant.SLO_CALLBACK_SET_KEY)
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
urls, ok := value.([]string)
|
||||
if !ok {
|
||||
return errors.New("convert into []string failed")
|
||||
}
|
||||
// range urls to make client logout
|
||||
for _, url := range urls {
|
||||
// join url
|
||||
newUrl, err := s.joinLoginIdAndSign(url, loginId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// sent http
|
||||
_, err = s.config.SendHttp(newUrl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// server logout
|
||||
return s.enforcer.LogoutById(loginId)
|
||||
}
|
||||
|
||||
// joinLoginIdAndSign splice the loginId to the url, and stitching parameters such as sign.
|
||||
func (s *SsoEnforcer) joinLoginIdAndSign(url string, id string) (string, error) {
|
||||
nonce, err := util.GenerateRandomString32()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
|
||||
sign, err := s.createSign(timestamp, nonce)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
str := url + "?" + s.paramName.LoginId + "=" + id + "&" + s.paramName.TimeStamp + "=" + timestamp + "&" + s.paramName.Nonce + "=" + nonce + "&" + s.paramName.Sign + "=" + sign
|
||||
return str, nil
|
||||
}
|
||||
|
||||
// buildServerAuthUrl SSO-Client build SSO-Server single sign-on url.
|
||||
func (s *SsoEnforcer) buildServerAuthUrl(clientLoginUrl string, back string) (string, error) {
|
||||
if clientLoginUrl == "" {
|
||||
return "", errors.New("arg[0] clientLoginUrl can not be nil")
|
||||
}
|
||||
|
||||
// get server auth url
|
||||
authUrl := s.config.SpliceAuthUrl()
|
||||
|
||||
client := s.config.Client
|
||||
|
||||
if client != "" {
|
||||
authUrl = util.AddQuery(authUrl, s.paramName.Client, client)
|
||||
}
|
||||
|
||||
// splice back url
|
||||
if back != "" {
|
||||
back = util.Encode(back)
|
||||
clientLoginUrl = util.AddQuery(clientLoginUrl, s.paramName.Back, back)
|
||||
}
|
||||
|
||||
return util.AddQuery(authUrl, s.paramName.Redirect, clientLoginUrl), nil
|
||||
}
|
||||
|
||||
// buildRedirectUrl the server gives the redirectUrl of the ticket to the client.
|
||||
// Check redirect url, delete old ticket of loginId and create new ticket, then return url with new ticket.
|
||||
func (s *SsoEnforcer) buildRedirectUrl(loginId string, client string, redirect string) (string, error) {
|
||||
// check redirect url
|
||||
err := s.CheckRedirectUrl(redirect)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// delete old ticket
|
||||
err = s.deleteTicket(s.GetTicket(loginId))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// create new ticket
|
||||
ticket, err := s.CreateTicket(loginId, client)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// return redirect + "?" + s.paramName.Ticket + "=" + ticket, nil
|
||||
return util.AddQuery(s.encodeBackParam(redirect), s.paramName.Ticket, ticket), nil
|
||||
}
|
||||
|
||||
// encodeBackParam find back param from url, and encode back param.
|
||||
func (s *SsoEnforcer) encodeBackParam(url string) string {
|
||||
// get back location
|
||||
index := strings.Index(url, "?"+s.paramName.Back+"=")
|
||||
if index == -1 {
|
||||
index = strings.Index(url, "&"+s.paramName.Back+"=")
|
||||
if index == -1 {
|
||||
return url
|
||||
}
|
||||
}
|
||||
|
||||
// encode
|
||||
length := len(s.paramName.Back) + 2
|
||||
back := url[index+length:]
|
||||
back = util.Encode(back)
|
||||
|
||||
// update back
|
||||
url = url[:index+length] + back
|
||||
return url
|
||||
}
|
||||
|
||||
// buildCheckTicketUrl build to check ticket.
|
||||
func (s *SsoEnforcer) buildCheckTicketUrl(ticket string, ssoLogoutCallUrl string) (string, error) {
|
||||
if ticket == "" {
|
||||
return "", errors.New("buildCheckTicketUrl() ticket can not be nil")
|
||||
}
|
||||
checkTicketUrl := s.config.SpliceCheckTicketUrl()
|
||||
client := s.config.Client
|
||||
paramMap := make(map[string]string)
|
||||
if client != "" {
|
||||
paramMap[s.paramName.Client] = client
|
||||
}
|
||||
paramMap[s.paramName.Ticket] = ticket
|
||||
if ssoLogoutCallUrl != "" {
|
||||
paramMap[s.paramName.SsoLogoutCall] = ssoLogoutCallUrl
|
||||
}
|
||||
|
||||
return util.AddQueryMap(checkTicketUrl, paramMap), nil
|
||||
}
|
||||
|
||||
// buildSloUrl build single-logout url.
|
||||
func (s *SsoEnforcer) buildSloUrl(loginId string) (string, error) {
|
||||
sloUrl := s.config.SpliceSloUrl()
|
||||
url, err := s.joinLoginIdAndSign(sloUrl, loginId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// buildGetDataUrl build getData url with sign,timestamp,nonce.
|
||||
func (s *SsoEnforcer) buildGetDataUrl(paramMap map[string]string) (string, error) {
|
||||
getDataUrl := s.config.SpliceGetDataUrl()
|
||||
return s.buildCustomPathUrl(getDataUrl, paramMap)
|
||||
}
|
||||
|
||||
// buildCustomPathUrl add paramMap to path.
|
||||
func (s *SsoEnforcer) buildCustomPathUrl(path string, paramMap map[string]string) (string, error) {
|
||||
u := path
|
||||
|
||||
if !strings.HasPrefix(u, "http") {
|
||||
serverUrl := s.config.ServerUrl
|
||||
if serverUrl == "" {
|
||||
return "", errors.New("please set sso serverUrl")
|
||||
}
|
||||
u = util.SpliceUrl(serverUrl, path)
|
||||
}
|
||||
// sign map
|
||||
timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
|
||||
paramMap[s.paramName.TimeStamp] = timestamp
|
||||
nonce, err := util.GenerateRandomString32()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
paramMap[s.paramName.Nonce] = nonce
|
||||
// create sign
|
||||
sign, err := s.createSign(timestamp, nonce)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
paramMap[s.paramName.Sign] = sign
|
||||
finalUrl := util.AddQueryMap(u, paramMap)
|
||||
return finalUrl, nil
|
||||
}
|
||||
|
||||
// request send http request and use json.Unmarshal to converted to *model.Result.
|
||||
func (s *SsoEnforcer) request(url string) (*model.Result, error) {
|
||||
resp, err := s.config.SendHttp(url)
|
||||
|
||||
log.Printf("http request response: %s", resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result := &model.Result{}
|
||||
err = json.Unmarshal([]byte(resp), result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *SsoEnforcer) GetAllowUrl() string {
|
||||
return s.config.AllowUrl
|
||||
}
|
||||
|
||||
// saveTicket save ticket-id+client.
|
||||
func (s *SsoEnforcer) saveTicket(ticket string, loginId string, client string) error {
|
||||
value := loginId
|
||||
if client != "" {
|
||||
value += "," + client
|
||||
}
|
||||
ticketTimeout := s.config.TicketTimeout
|
||||
return s.enforcer.GetAdapter().SetStr(s.spliceTicketSaveKey(ticket), value, ticketTimeout)
|
||||
}
|
||||
|
||||
// saveTicketIndex save id-ticket.
|
||||
func (s *SsoEnforcer) saveTicketIndex(ticket string, id string) error {
|
||||
ticketTimeout := s.config.TicketTimeout
|
||||
return s.enforcer.GetAdapter().SetStr(s.spliceTicketIndexKey(id), ticket, ticketTimeout)
|
||||
}
|
||||
|
||||
// spliceTicketSaveKey splice ticket-id key.
|
||||
func (s *SsoEnforcer) spliceTicketSaveKey(ticket string) string {
|
||||
return s.enforcer.GetTokenConfig().TokenName + ":ticket:" + ticket
|
||||
}
|
||||
|
||||
// spliceTicketIndexKey splice id-ticket key.
|
||||
func (s *SsoEnforcer) spliceTicketIndexKey(id string) string {
|
||||
return s.enforcer.GetTokenConfig().TokenName + ":id-ticket:" + id
|
||||
}
|
||||
|
||||
func (s *SsoEnforcer) deleteTicket(ticket string) error {
|
||||
if ticket == "" {
|
||||
return nil
|
||||
}
|
||||
return s.enforcer.GetAdapter().DeleteStr(s.spliceTicketSaveKey(ticket))
|
||||
}
|
||||
|
||||
func (s *SsoEnforcer) deleteTicketIndex(id string) error {
|
||||
if id == "" {
|
||||
return nil
|
||||
}
|
||||
return s.enforcer.GetAdapter().DeleteStr(s.spliceTicketIndexKey(id))
|
||||
}
|
||||
|
||||
// checkTimeStamp determine whether the gap between the timestamp and the current timestamp is within the allowable range.
|
||||
func (s *SsoEnforcer) checkTimeStamp(timestamp string) error {
|
||||
parseInt, err := strconv.ParseInt(timestamp, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !s.isValidTimeStamp(parseInt) {
|
||||
return errors.New("timestamp is out of allowed range: " + timestamp)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isValidTimeStamp determine whether the gap between the timestamp and the current timestamp is within the allowable range.
|
||||
func (s *SsoEnforcer) isValidTimeStamp(timestamp int64) bool {
|
||||
allowDisparity := s.signConfig.TimeStampDisparity
|
||||
nowDisparity := time.Now().UnixMilli() - timestamp
|
||||
|
||||
return allowDisparity == 1 || nowDisparity <= allowDisparity
|
||||
}
|
||||
|
||||
// checkNonce the same nonce can only be verified once.
|
||||
func (s *SsoEnforcer) checkNonce(nonce string) error {
|
||||
if nonce == "" {
|
||||
return errors.New("nonce is nil")
|
||||
}
|
||||
// if nonce exists in adapter
|
||||
if !s.isValidNonce(nonce) {
|
||||
return errors.New("the nonce has been used: " + nonce)
|
||||
}
|
||||
// set nonce after nonce
|
||||
err := s.enforcer.GetAdapter().SetStr(s.spliceNonceSaveKey(nonce), nonce, s.signConfig.GetSaveNonceExpire()*2+2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isValidNonce determine random string, if not exist in adapter, return true.
|
||||
func (s *SsoEnforcer) isValidNonce(nonce string) bool {
|
||||
if nonce == "" {
|
||||
return false
|
||||
}
|
||||
key := s.spliceNonceSaveKey(nonce)
|
||||
// if not exist in adapter, return true.
|
||||
return s.enforcer.GetAdapter().GetStr(key) == ""
|
||||
}
|
||||
|
||||
func (s *SsoEnforcer) checkSign(timestamp string, nonce string, sign string) error {
|
||||
valid, err := s.isValidSign(timestamp, nonce, sign)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !valid {
|
||||
return errors.New("invalid sign: " + sign)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isValidSign use timestamp,nonce,sign to createSign to compare.
|
||||
func (s *SsoEnforcer) isValidSign(timestamp string, nonce string, sign string) (bool, error) {
|
||||
recreateSign, err := s.createSign(timestamp, nonce)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return recreateSign == sign, nil
|
||||
}
|
||||
|
||||
// createSign use util.MD5() to generate str.
|
||||
func (s *SsoEnforcer) createSign(timestamp string, nonce string) (string, error) {
|
||||
secretKey := s.signConfig.SecretKey
|
||||
if secretKey == "" {
|
||||
return "", errors.New("please check SignConfig.SecretKey, SecretKey can not be nil")
|
||||
}
|
||||
str := s.paramName.Nonce + "=" + nonce + "&" + s.paramName.TimeStamp + "=" + timestamp + "&" + s.paramName.SecretKet + "=" + secretKey
|
||||
|
||||
return util.MD5(str), nil
|
||||
}
|
||||
|
||||
// spliceNonceSaveKey splice nonce store key.
|
||||
func (s *SsoEnforcer) spliceNonceSaveKey(nonce string) string {
|
||||
return s.enforcer.GetTokenConfig().TokenName + ":sign:nonce:" + nonce
|
||||
}
|
135
sso/sso_server_api.go
Normal file
135
sso/sso_server_api.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package sso
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/weloe/token-go/constant"
|
||||
"github.com/weloe/token-go/ctx"
|
||||
"github.com/weloe/token-go/model"
|
||||
)
|
||||
|
||||
/**
|
||||
=========processor SSO-Server api
|
||||
*/
|
||||
|
||||
// SsoAuth SSO-Server: auth.
|
||||
func (s *SsoEnforcer) SsoAuth(ctx ctx.Context) (interface{}, error) {
|
||||
request := ctx.Request()
|
||||
response := ctx.Response()
|
||||
|
||||
isLogin, err := s.enforcer.IsLogin(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// if you have not logged in to the SSO-Server, need to log in first
|
||||
if !isLogin {
|
||||
return s.config.NotLoginView(), nil
|
||||
}
|
||||
// if you have logged, check the mode
|
||||
mode := request.Query(s.paramName.Mode)
|
||||
|
||||
redirect := request.Query(s.paramName.Redirect)
|
||||
|
||||
// if mode == simple, redirect to client directly
|
||||
if mode == constant.MODE_SIMPLE {
|
||||
err = s.CheckRedirectUrl(redirect)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response.Redirect(redirect)
|
||||
return nil, nil
|
||||
} else {
|
||||
// mode = ticket, redirect to client login with new ticket
|
||||
id, err := s.enforcer.GetLoginId(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
redirectUrl, err := s.buildRedirectUrl(id, request.Query(s.paramName.Client), redirect)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
response.Redirect(redirectUrl)
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// SsoDoLogin SSO-Server: rest login api.
|
||||
func (s *SsoEnforcer) SsoDoLogin(ctx ctx.Context) (interface{}, error) {
|
||||
request := ctx.Request()
|
||||
paramName := s.paramName
|
||||
if s.config.DoLoginHandle == nil {
|
||||
return nil, errors.New("SsoConfig.DoLoginHandle is nil")
|
||||
}
|
||||
resp, err := s.config.DoLoginHandle(request.Query(paramName.Name), request.Query(paramName.Pwd), ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// SsoCheckTicket SSO-Server: check ticket to get loginId, returns loginId
|
||||
func (s *SsoEnforcer) SsoCheckTicket(ctx ctx.Context) (interface{}, error) {
|
||||
paramName := s.paramName
|
||||
request := ctx.Request()
|
||||
client := request.Query(paramName.Client)
|
||||
ticket := request.Query(paramName.Ticket)
|
||||
if ticket == "" {
|
||||
return nil, errors.New("ticket can not be nil")
|
||||
}
|
||||
sloCallback := request.Query(paramName.SsoLogoutCall)
|
||||
|
||||
// check ticket
|
||||
loginId, err := s.CheckTicketByClient(ticket, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// register single sign out callback url
|
||||
err = s.RegisterSloCallbackUrl(loginId, sloCallback)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if loginId == "" {
|
||||
return nil, errors.New("invalid ticket: " + ticket)
|
||||
}
|
||||
return model.Ok().SetData(loginId), nil
|
||||
}
|
||||
|
||||
// SsoSignOut SSO-Server: single sign-out.
|
||||
func (s *SsoEnforcer) SsoSignOut(ctx ctx.Context) (interface{}, error) {
|
||||
request := ctx.Request()
|
||||
paramName := s.paramName
|
||||
|
||||
// if enable single sign-out and request param has loginId
|
||||
reqLoginId := request.Query(paramName.LoginId)
|
||||
if s.config.IsSlo && reqLoginId == "" {
|
||||
loginId, err := s.enforcer.GetLoginId(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if loginId != "" {
|
||||
err = s.ssoSignOutById(loginId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// callback
|
||||
return s.ssoLogoutBack(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// if enable http,single sign-out and request param has loginId
|
||||
if s.config.IsHttp && s.config.IsSlo && reqLoginId != "" {
|
||||
err := s.checkRequest(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Use loginId to get single sign-out urls from session, traverse the urls to send the request to notify client
|
||||
err = s.ssoSignOutById(reqLoginId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return model.Ok().SetMsg("sso sign-out success"), nil
|
||||
}
|
||||
|
||||
return nil, errors.New("not handle")
|
||||
}
|
103
sso/sso_test.go
Normal file
103
sso/sso_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package sso
|
||||
|
||||
import (
|
||||
tokenGo "github.com/weloe/token-go"
|
||||
"github.com/weloe/token-go/config"
|
||||
"github.com/weloe/token-go/ctx"
|
||||
"github.com/weloe/token-go/model"
|
||||
"github.com/weloe/token-go/util"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func SendGetRequest(url string) (string, error) {
|
||||
return util.SendGetRequest(url)
|
||||
}
|
||||
|
||||
func TestNewSsoServerEnforcer(t *testing.T) {
|
||||
var err error
|
||||
// use default adapter
|
||||
adapter := tokenGo.NewDefaultAdapter()
|
||||
enforcer, err := tokenGo.NewEnforcer(adapter)
|
||||
if err != nil {
|
||||
t.Errorf("NewEnforcer() failed: %v", err)
|
||||
}
|
||||
// enable logger
|
||||
enforcer.EnableLog()
|
||||
ssoOptions := &config.SsoOptions{
|
||||
Mode: "",
|
||||
TicketTimeout: 300,
|
||||
AllowUrl: "*",
|
||||
IsSlo: true,
|
||||
|
||||
IsHttp: true,
|
||||
ServerUrl: "http://token-go-sso-server.com:9000",
|
||||
NotLoginView: func() interface{} {
|
||||
msg := "not log in SSO-Server, please visit <a href='/sso/doLogin?name=tokengo&pwd=123456' target='_blank'> doLogin </a>"
|
||||
return msg
|
||||
},
|
||||
DoLoginHandle: func(name string, pwd string, ctx ctx.Context) (interface{}, error) {
|
||||
if name != "tokengo" {
|
||||
return "name error", nil
|
||||
}
|
||||
if pwd != "123456" {
|
||||
return "pwd error", nil
|
||||
}
|
||||
token, err := enforcer.Login("1001", ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return model.Ok().SetData(token), nil
|
||||
},
|
||||
SendHttp: func(url string) (string, error) {
|
||||
return SendGetRequest(url)
|
||||
},
|
||||
}
|
||||
signOptions := &config.SignOptions{
|
||||
SecretKey: "kQwIOrYvnXmSDkwEiFngrKidMcdrgKor",
|
||||
IsCheckNonce: true,
|
||||
}
|
||||
ssoEnforcer, err := NewSsoEnforcer(&Options{
|
||||
SsoOptions: ssoOptions,
|
||||
SignOptions: signOptions,
|
||||
Enforcer: enforcer,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("NewSsoEnforcer() failed: %v", err)
|
||||
}
|
||||
t.Logf("enforcer: %v", ssoEnforcer)
|
||||
}
|
||||
func TestNewSsoClient3Enforcer(t *testing.T) {
|
||||
var err error
|
||||
// use default adapter
|
||||
adapter := tokenGo.NewDefaultAdapter()
|
||||
enforcer, err := tokenGo.NewEnforcer(adapter)
|
||||
if err != nil {
|
||||
t.Errorf("NewEnforcer() failed: %v", err)
|
||||
}
|
||||
// enable logger
|
||||
enforcer.EnableLog()
|
||||
ssoOptions := &config.SsoOptions{
|
||||
AuthUrl: "/sso/auth",
|
||||
IsSlo: true,
|
||||
IsHttp: true,
|
||||
SloUrl: "/sso/signout",
|
||||
CheckTicketUrl: "/sso/checkTicket",
|
||||
ServerUrl: "http://token-go-sso-server.com:9000",
|
||||
SendHttp: func(url string) (string, error) {
|
||||
return SendGetRequest(url)
|
||||
},
|
||||
}
|
||||
signOptions := &config.SignOptions{
|
||||
SecretKey: "kQwIOrYvnXmSDkwEiFngrKidMcdrgKor",
|
||||
IsCheckNonce: true,
|
||||
}
|
||||
ssoEnforcer, err := NewSsoEnforcer(&Options{
|
||||
SsoOptions: ssoOptions,
|
||||
SignOptions: signOptions,
|
||||
Enforcer: enforcer,
|
||||
})
|
||||
if err != nil {
|
||||
t.Errorf("NewSsoEnforcer() failed: %v", err)
|
||||
}
|
||||
t.Logf("enforcer: %v", ssoEnforcer)
|
||||
}
|
13
util/secure.go
Normal file
13
util/secure.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func MD5(str string) string {
|
||||
data := []byte(str)
|
||||
has := md5.Sum(data)
|
||||
md5str := fmt.Sprintf("%x", has)
|
||||
return md5str
|
||||
}
|
12
util/sign.go
Normal file
12
util/sign.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// IsValidTimeStamp determine whether the gap between the startTime and the current timestamp is within the allowable range.
|
||||
func IsValidTimeStamp(startTime int64, allowDisparity int64) bool {
|
||||
nowDisparity := time.Now().UnixMilli() - startTime
|
||||
|
||||
return allowDisparity == 1 || nowDisparity <= allowDisparity
|
||||
}
|
137
util/url.go
Normal file
137
util/url.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func SendGetRequest(url string) (string, error) {
|
||||
response, err := http.Get(url)
|
||||
if err != nil {
|
||||
log.Printf("http.Get() failed: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
defer func(Body io.ReadCloser) {
|
||||
err = Body.Close()
|
||||
if err != nil {
|
||||
log.Printf("read response body failed: %v", err)
|
||||
}
|
||||
}(response.Body)
|
||||
|
||||
body, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
log.Printf("ioutil.ReadAll() failed: %v", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
// SpliceUrl splice two url.
|
||||
// Examples:
|
||||
// u1 = "http://domain.com" u2 = "/sso/auth" return http://domain.com/sso/auth
|
||||
func SpliceUrl(u1 string, u2 string) string {
|
||||
if u1 == "" {
|
||||
return u2
|
||||
}
|
||||
if u2 == "" {
|
||||
return u1
|
||||
}
|
||||
|
||||
if strings.HasPrefix(u2, "http") {
|
||||
return u2
|
||||
}
|
||||
|
||||
return u1 + u2
|
||||
}
|
||||
|
||||
func HasUrl(urls []string, url string) bool {
|
||||
for _, s := range urls {
|
||||
if MatchUrl(s, url) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func MatchUrl(pattern string, url string) bool {
|
||||
if pattern == "*" {
|
||||
return true
|
||||
}
|
||||
return pattern == url
|
||||
}
|
||||
|
||||
func IsValidUrl(u1 string) bool {
|
||||
|
||||
_, err := url.ParseRequestURI(u1)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
u, err := url.Parse(u1)
|
||||
if err != nil || u.Scheme == "" || u.Host == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// check if the URL has a valid scheme (http or https)
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AddQueryMap add map param for the path.
|
||||
func AddQueryMap(path string, paramMap map[string]string) string {
|
||||
queryString := MapToQuery(paramMap)
|
||||
|
||||
return AddQueryValue(path, queryString)
|
||||
}
|
||||
|
||||
func AddQueryValue(path string, queryString string) string {
|
||||
index := strings.LastIndex(path, "?")
|
||||
// if the path is not included "?"
|
||||
if index == -1 {
|
||||
return path + "?" + queryString
|
||||
}
|
||||
// if the last is "?"
|
||||
if index == len(path)-1 {
|
||||
return path + queryString
|
||||
}
|
||||
|
||||
// if "?" inside path, the last is not "&" and queryString's first string is not "&"
|
||||
if index < len(path)-1 {
|
||||
if strings.LastIndex(path, "&") != len(path)-1 && strings.Index(path, "&") != 0 {
|
||||
return path + "&" + queryString
|
||||
} else {
|
||||
return path + queryString
|
||||
}
|
||||
}
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
// AddQuery add query param for the path.
|
||||
func AddQuery(path string, key string, value string) string {
|
||||
queryString := key + "=" + value
|
||||
return AddQueryValue(path, queryString)
|
||||
}
|
||||
|
||||
// MapToQuery convert map to k=v array, and use "&" to join.
|
||||
func MapToQuery(paramMap map[string]string) string {
|
||||
var queryString []string
|
||||
for k, v := range paramMap {
|
||||
queryString = append(queryString, k+"="+v)
|
||||
}
|
||||
query := strings.Join(queryString, "&")
|
||||
return query
|
||||
}
|
||||
|
||||
func Encode(u string) string {
|
||||
return url.QueryEscape(u)
|
||||
}
|
144
util/url_test.go
Normal file
144
util/url_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"log"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSpliceUrl(t *testing.T) {
|
||||
type args struct {
|
||||
u1 string
|
||||
u2 string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "success1",
|
||||
args: args{
|
||||
u1: "http://domain.com",
|
||||
u2: "/sso/auth",
|
||||
},
|
||||
want: "http://domain.com/sso/auth",
|
||||
},
|
||||
{
|
||||
name: "success2",
|
||||
args: args{
|
||||
u1: "",
|
||||
u2: "http://domain.com/sso/auth",
|
||||
},
|
||||
want: "http://domain.com/sso/auth",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := SpliceUrl(tt.args.u1, tt.args.u2); got != tt.want {
|
||||
t.Errorf("SpliceUrl() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasUrl(t *testing.T) {
|
||||
strings := []string{"1", "2", "3"}
|
||||
hasUrl := HasUrl(strings, "2")
|
||||
if !hasUrl {
|
||||
t.Errorf("HasUrl() = %v, want %v", false, true)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchUrl(t *testing.T) {
|
||||
all := "*"
|
||||
if !MatchUrl(all, "123") {
|
||||
t.Errorf("HasUrl() = %v, want %v", false, true)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidUrl(t *testing.T) {
|
||||
type args struct {
|
||||
u1 string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
|
||||
{
|
||||
name: "error url1",
|
||||
args: args{u1: "htp:23/asd"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "validated",
|
||||
args: args{u1: "http://123.com:90//ac"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "error url2",
|
||||
args: args{u1: "http:/123.com:90//ac"},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := IsValidUrl(tt.args.u1); got != tt.want {
|
||||
t.Errorf("IsValidUrl() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddQueryMap(t *testing.T) {
|
||||
m := make(map[string]string)
|
||||
m["1"] = "2"
|
||||
m["2"] = "3"
|
||||
|
||||
if got := AddQueryMap("/sso/auth", m); got != "/sso/auth?1=2&2=3" && got != "/sso/auth?2=3&1=2" {
|
||||
t.Errorf("AddQueryMap() = %v, want %v", got, "/sso/auth?1=2&2=3")
|
||||
}
|
||||
|
||||
if got := AddQueryMap("/sso/auth?", m); got != "/sso/auth?1=2&2=3" && got != "/sso/auth?2=3&1=2" {
|
||||
t.Errorf("AddQueryMap() = %v, want %v", got, "/sso/auth?1=2&2=3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddQuery(t *testing.T) {
|
||||
if got := AddQuery("/sso", "1", "2"); got != "/sso?1=2" {
|
||||
t.Errorf("AddQuery() = %v, want %v", got, "/sso?1=2")
|
||||
}
|
||||
if got := AddQuery("/sso?", "1", "2"); got != "/sso?1=2" {
|
||||
t.Errorf("AddQuery() = %v, want %v", got, "/sso?1=2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapToQuery(t *testing.T) {
|
||||
m := make(map[string]string)
|
||||
m["1"] = "2"
|
||||
query := MapToQuery(m)
|
||||
if query != "1=2" {
|
||||
t.Errorf("MapToQuery() = %v, want %v", query, "1=2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncode(t *testing.T) {
|
||||
encode := Encode("abc123==123")
|
||||
log.Print("Encode(\"abc123==123\") = " + encode)
|
||||
}
|
||||
|
||||
func TestAddQueryValue(t *testing.T) {
|
||||
if got := AddQueryValue("/sso/auth?back=http://123.com/login", "ticket=23324"); got != "/sso/auth?back=http://123.com/login&ticket=23324" {
|
||||
t.Errorf("AddQueryValue() = %v, want %v", got, "/sso/auth?back=http://123.com/login&ticket=23324")
|
||||
}
|
||||
if got := AddQueryValue("/sso/auth?back=http://123.com/login?", "ticket=23324"); got != "/sso/auth?back=http://123.com/login?ticket=23324" {
|
||||
t.Errorf("AddQueryValue() = %v, want %v", got, "/sso/auth?back=http://123.com/login?ticket=23324")
|
||||
}
|
||||
if got := AddQueryValue("/sso/auth?back=http://123.com/login?ticket=123", "redirect=23324"); got != "/sso/auth?back=http://123.com/login?ticket=123&redirect=23324" {
|
||||
t.Errorf("AddQueryValue() = %v, want %v", got, "/sso/auth?back=http://123.com/login?ticket=123&redirect=23324")
|
||||
}
|
||||
if got := AddQueryValue("/sso/auth?back=http://123.com/login?ticket=123&", "redirect=23324"); got != "/sso/auth?back=http://123.com/login?ticket=123&redirect=23324" {
|
||||
t.Errorf("AddQueryValue() = %v, want %v", got, "/sso/auth?back=http://123.com/login?ticket=123&redirect=23324")
|
||||
}
|
||||
}
|
@@ -26,3 +26,12 @@ func InterfaceToBytes(data interface{}) ([]byte, error) {
|
||||
}
|
||||
return nil, fmt.Errorf("unable to convert %T to []byte", data)
|
||||
}
|
||||
|
||||
// AppendStr do not add repeated str.
|
||||
// If old slice has newStr, return directly, else append
|
||||
func AppendStr(old []string, newStr string) []string {
|
||||
if HasStr(old, newStr) {
|
||||
return old
|
||||
}
|
||||
return append(old, newStr)
|
||||
}
|
||||
|
22
util/util_test.go
Normal file
22
util/util_test.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAppendStr(t *testing.T) {
|
||||
strings := []string{"1", "2"}
|
||||
str := AppendStr(strings, "1")
|
||||
str = AppendStr(str, "3")
|
||||
for i, s := range strings {
|
||||
if s == "1" && i == 2 {
|
||||
t.Errorf("AppendStr() = %v, want %v", str, []string{"1", "2"})
|
||||
}
|
||||
|
||||
}
|
||||
for i, s := range strings {
|
||||
if i == 2 && s != "3" {
|
||||
t.Errorf("AppendStr() = %v, want %v", str, []string{"1", "2", "3"})
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user