mirror of
https://github.com/weloe/token-go.git
synced 2025-10-07 00:12:58 +08:00
feat: support SSO
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
# Token-Go
|
# Token-Go
|
||||||
|
|
||||||
This library focuses on solving login authentication problems, such as: login, multi-account login, shared token, logout, kickout ...
|
This library focuses on solving login authentication problems, such as: login, multi-account login, shared token, logout, kickout, banned, SSO ...
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
@@ -189,6 +189,9 @@ func CheckAuth(w http.ResponseWriter, req *http.Request) {
|
|||||||
fmt.Fprintf(w, "you have authorization")
|
fmt.Fprintf(w, "you have authorization")
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
## SSO
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Api
|
## Api
|
||||||
|
|
||||||
|
30
config/sign.go
Normal file
30
config/sign.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
// SignConfig sign config
|
||||||
|
type SignConfig struct {
|
||||||
|
SecretKey string
|
||||||
|
TimeStampDisparity int64
|
||||||
|
IsCheckNonce bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSignConfig(options *SignOptions) (*SignConfig, error) {
|
||||||
|
if options == nil {
|
||||||
|
options = &SignOptions{}
|
||||||
|
}
|
||||||
|
if options.TimeStampDisparity == 0 {
|
||||||
|
options.TimeStampDisparity = 1000 * 60 * 15
|
||||||
|
}
|
||||||
|
return &SignConfig{
|
||||||
|
SecretKey: options.SecretKey,
|
||||||
|
TimeStampDisparity: options.TimeStampDisparity,
|
||||||
|
IsCheckNonce: options.IsCheckNonce,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SignConfig) GetSaveNonceExpire() int64 {
|
||||||
|
if s.TimeStampDisparity >= 0 {
|
||||||
|
return s.TimeStampDisparity / 1000
|
||||||
|
} else {
|
||||||
|
return 60 * 60 * 24
|
||||||
|
}
|
||||||
|
}
|
168
config/sso.go
Normal file
168
config/sso.go
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/weloe/token-go/ctx"
|
||||||
|
model2 "github.com/weloe/token-go/model"
|
||||||
|
"github.com/weloe/token-go/util"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DefaultSsoConfig(serverUrl string, notLoginView func() interface{},
|
||||||
|
doLoginHandle func(name string, pwd string, ctx ctx.Context) (interface{}, error),
|
||||||
|
ticketResultHandle func(o1 string, s string) (interface{}, error),
|
||||||
|
sendHttp func(url string) (string, error)) *SsoConfig {
|
||||||
|
if notLoginView == nil {
|
||||||
|
notLoginView = func() interface{} {
|
||||||
|
return "not logged in to the SSO-Server"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &SsoConfig{
|
||||||
|
Mode: "",
|
||||||
|
TicketTimeout: 60 * 5,
|
||||||
|
AllowUrl: "*",
|
||||||
|
IsSlo: true,
|
||||||
|
IsHttp: false,
|
||||||
|
Client: "",
|
||||||
|
AuthUrl: "/sso/auth",
|
||||||
|
CheckTicketUrl: "/sso/checkTicket",
|
||||||
|
GetDataUrl: "/sso/getData",
|
||||||
|
UserInfoUrl: "/sso/userInfo",
|
||||||
|
SloUrl: "/sso/signOut",
|
||||||
|
SsoLogoutCall: "",
|
||||||
|
ServerUrl: serverUrl,
|
||||||
|
NotLoginView: notLoginView,
|
||||||
|
DoLoginHandle: doLoginHandle,
|
||||||
|
TicketResultHandle: ticketResultHandle,
|
||||||
|
SendHttp: sendHttp,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSsoConfig(options *SsoOptions) (*SsoConfig, error) {
|
||||||
|
if options == nil {
|
||||||
|
options = &SsoOptions{}
|
||||||
|
}
|
||||||
|
if options.TicketTimeout == 0 {
|
||||||
|
options.TicketTimeout = 60 * 5
|
||||||
|
}
|
||||||
|
if options.AllowUrl == "" {
|
||||||
|
options.AllowUrl = "*"
|
||||||
|
}
|
||||||
|
if options.AuthUrl == "" {
|
||||||
|
options.AuthUrl = "/sso/auth"
|
||||||
|
}
|
||||||
|
if options.CheckTicketUrl == "" {
|
||||||
|
options.CheckTicketUrl = "/sso/checkTicket"
|
||||||
|
}
|
||||||
|
if options.GetDataUrl == "" {
|
||||||
|
options.GetDataUrl = "/sso/getData"
|
||||||
|
}
|
||||||
|
if options.UserInfoUrl == "" {
|
||||||
|
options.UserInfoUrl = "/sso/userInfo"
|
||||||
|
}
|
||||||
|
if options.SloUrl == "" {
|
||||||
|
options.SloUrl = "/sso/signout"
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.NotLoginView == nil {
|
||||||
|
options.NotLoginView = func() interface{} {
|
||||||
|
return "not logged in to the SSO-Server"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if options.DoLoginHandle == nil {
|
||||||
|
options.DoLoginHandle = func(name string, pwd string, ctx ctx.Context) (interface{}, error) {
|
||||||
|
return model2.Error(), errors.New("SsoConfig.DoLoginHandle is nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if options.IsHttp && options.SendHttp == nil {
|
||||||
|
return nil, errors.New("please config SSO SentHttp")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SsoConfig{
|
||||||
|
Mode: options.Mode,
|
||||||
|
TicketTimeout: options.TicketTimeout,
|
||||||
|
AllowUrl: options.AllowUrl,
|
||||||
|
IsSlo: options.IsSlo,
|
||||||
|
IsHttp: options.IsHttp,
|
||||||
|
Client: options.Client,
|
||||||
|
AuthUrl: options.AuthUrl,
|
||||||
|
CheckTicketUrl: options.CheckTicketUrl,
|
||||||
|
GetDataUrl: options.GetDataUrl,
|
||||||
|
UserInfoUrl: options.UserInfoUrl,
|
||||||
|
SloUrl: options.SloUrl,
|
||||||
|
SsoLogoutCall: options.SsoLogoutCall,
|
||||||
|
ServerUrl: options.ServerUrl,
|
||||||
|
NotLoginView: options.NotLoginView,
|
||||||
|
DoLoginHandle: options.DoLoginHandle,
|
||||||
|
TicketResultHandle: options.TicketResultHandle,
|
||||||
|
SendHttp: options.SendHttp,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type SsoConfig struct {
|
||||||
|
|
||||||
|
// Mode sso mode
|
||||||
|
Mode string
|
||||||
|
// TicketTimeout ticket timeout
|
||||||
|
TicketTimeout int64
|
||||||
|
// AllowUrl All allowed authorization callback addresses, separated by ','
|
||||||
|
AllowUrl string
|
||||||
|
IsSlo bool
|
||||||
|
IsHttp bool
|
||||||
|
|
||||||
|
// SSO-Client current client name
|
||||||
|
Client string
|
||||||
|
// SSO-Server auth url
|
||||||
|
AuthUrl string
|
||||||
|
// SSO-Server check ticket url
|
||||||
|
CheckTicketUrl string
|
||||||
|
GetDataUrl string
|
||||||
|
UserInfoUrl string
|
||||||
|
SloUrl string
|
||||||
|
SsoLogoutCall string
|
||||||
|
|
||||||
|
ServerUrl string
|
||||||
|
|
||||||
|
/**
|
||||||
|
sso callback func
|
||||||
|
*/
|
||||||
|
// NotLoginView
|
||||||
|
NotLoginView func() interface{}
|
||||||
|
|
||||||
|
// DoLoginHandle login func
|
||||||
|
DoLoginHandle func(name string, pwd string, ctx ctx.Context) (interface{}, error)
|
||||||
|
|
||||||
|
// TicketResultHandle called each time the result of the validation ticket is obtained from the SSO-Server
|
||||||
|
TicketResultHandle func(loginId string, back string) (interface{}, error)
|
||||||
|
|
||||||
|
// SendHttp sent http
|
||||||
|
SendHttp func(url string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SpliceAuthUrl return Server-side single sign-on authorization address
|
||||||
|
func (c *SsoConfig) SpliceAuthUrl() string {
|
||||||
|
return util.SpliceUrl(c.ServerUrl, c.AuthUrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SpliceCheckTicketUrl return ticket verification address on the server side
|
||||||
|
func (c *SsoConfig) SpliceCheckTicketUrl() string {
|
||||||
|
return util.SpliceUrl(c.ServerUrl, c.CheckTicketUrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SsoConfig) SpliceGetDataUrl() string {
|
||||||
|
return util.SpliceUrl(c.ServerUrl, c.GetDataUrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SsoConfig) SpliceUserInfoUrl() string {
|
||||||
|
return util.SpliceUrl(c.ServerUrl, c.UserInfoUrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SsoConfig) SpliceSloUrl() string {
|
||||||
|
return util.SpliceUrl(c.ServerUrl, c.SloUrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAllow set allow callback url
|
||||||
|
func (c *SsoConfig) SetAllow(urls ...string) *SsoConfig {
|
||||||
|
c.AllowUrl = strings.Join(urls, ",")
|
||||||
|
return c
|
||||||
|
}
|
52
config/sso_options.go
Normal file
52
config/sso_options.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "github.com/weloe/token-go/ctx"
|
||||||
|
|
||||||
|
// SsoOptions new SsoConfig options.
|
||||||
|
type SsoOptions struct {
|
||||||
|
CookieDomain string
|
||||||
|
|
||||||
|
// Mode sso mode
|
||||||
|
Mode string
|
||||||
|
// TicketTimeout ticket timeout
|
||||||
|
TicketTimeout int64
|
||||||
|
// AllowUrl All allowed authorization callback addresses, separated by ','
|
||||||
|
AllowUrl string
|
||||||
|
IsSlo bool
|
||||||
|
IsHttp bool
|
||||||
|
|
||||||
|
// SSO-Client current client name
|
||||||
|
Client string
|
||||||
|
// SSO-Server auth url
|
||||||
|
AuthUrl string
|
||||||
|
// SSO-Server check ticket url
|
||||||
|
CheckTicketUrl string
|
||||||
|
GetDataUrl string
|
||||||
|
UserInfoUrl string
|
||||||
|
SloUrl string
|
||||||
|
SsoLogoutCall string
|
||||||
|
|
||||||
|
ServerUrl string
|
||||||
|
|
||||||
|
/**
|
||||||
|
sso callback func
|
||||||
|
*/
|
||||||
|
// NotLoginView
|
||||||
|
NotLoginView func() interface{}
|
||||||
|
|
||||||
|
// DoLoginHandle login func
|
||||||
|
DoLoginHandle func(name string, pwd string, ctx ctx.Context) (interface{}, error)
|
||||||
|
|
||||||
|
// TicketResultHandle called each time the result of the validation ticket is obtained from the SSO-Server
|
||||||
|
TicketResultHandle func(loginId string, back string) (interface{}, error)
|
||||||
|
|
||||||
|
// SendHttp sent http
|
||||||
|
SendHttp func(url string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignOptions SignConfig options
|
||||||
|
type SignOptions struct {
|
||||||
|
SecretKey string
|
||||||
|
TimeStampDisparity int64
|
||||||
|
IsCheckNonce bool
|
||||||
|
}
|
11
constant/sso.go
Normal file
11
constant/sso.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
const (
|
||||||
|
SLO_CALLBACK_SET_KEY = "SLO_CALLBACK_SET_KEY_"
|
||||||
|
|
||||||
|
SELF = "self"
|
||||||
|
|
||||||
|
MODE_SIMPLE = "simple"
|
||||||
|
|
||||||
|
NOT_HANDLE = "{\"msg\": \"not handle\"}"
|
||||||
|
)
|
39
model/http_result.go
Normal file
39
model/http_result.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
const (
|
||||||
|
SUCCESS = 1
|
||||||
|
ERROR = 0
|
||||||
|
)
|
||||||
|
|
||||||
|
// Result wrap the http request result.
|
||||||
|
type Result struct {
|
||||||
|
Code int
|
||||||
|
Msg string
|
||||||
|
Data interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Ok() *Result {
|
||||||
|
return &Result{
|
||||||
|
Code: SUCCESS,
|
||||||
|
Msg: "success",
|
||||||
|
Data: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Error() *Result {
|
||||||
|
return &Result{
|
||||||
|
Code: -1,
|
||||||
|
Msg: "error",
|
||||||
|
Data: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Result) SetData(data interface{}) *Result {
|
||||||
|
r.Data = data
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Result) SetMsg(msg string) *Result {
|
||||||
|
r.Msg = msg
|
||||||
|
return r
|
||||||
|
}
|
34
sso/api.go
Normal file
34
sso/api.go
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
package sso
|
||||||
|
|
||||||
|
// ApiName sso api name, used to dispatcher request.
|
||||||
|
type ApiName struct {
|
||||||
|
// sso-server auth url
|
||||||
|
SsoAuth string
|
||||||
|
// sso-server rest api login url
|
||||||
|
SsoDoLogin string
|
||||||
|
// sso-server check ticket url
|
||||||
|
SsoCheckTicket string
|
||||||
|
// sso-server get user info url
|
||||||
|
SsoUserInfo string
|
||||||
|
// sso-server single logout url
|
||||||
|
SsoSignout string
|
||||||
|
// sso-client login url
|
||||||
|
SsoLogin string
|
||||||
|
// sso-client single logout url
|
||||||
|
SsoLogout string
|
||||||
|
// sso-client logout callback url
|
||||||
|
SsoLogoutCall string
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultApiName() *ApiName {
|
||||||
|
return &ApiName{
|
||||||
|
SsoAuth: "/sso/auth",
|
||||||
|
SsoDoLogin: "/sso/doLogin",
|
||||||
|
SsoCheckTicket: "/sso/checkTicket",
|
||||||
|
SsoUserInfo: "/sso/userInfo",
|
||||||
|
SsoSignout: "/sso/signout",
|
||||||
|
SsoLogin: "/sso/login",
|
||||||
|
SsoLogout: "/sso/logout",
|
||||||
|
SsoLogoutCall: "/sso/logoutCall",
|
||||||
|
}
|
||||||
|
}
|
39
sso/param.go
Normal file
39
sso/param.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package sso
|
||||||
|
|
||||||
|
// ParamName http request param name.
|
||||||
|
type ParamName struct {
|
||||||
|
Redirect string
|
||||||
|
Ticket string
|
||||||
|
Back string
|
||||||
|
Mode string
|
||||||
|
LoginId string
|
||||||
|
Client string
|
||||||
|
SsoLogoutCall string
|
||||||
|
Name string
|
||||||
|
Pwd string
|
||||||
|
|
||||||
|
//=== sign param
|
||||||
|
|
||||||
|
TimeStamp string
|
||||||
|
Nonce string
|
||||||
|
Sign string
|
||||||
|
SecretKet string
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultParamName() *ParamName {
|
||||||
|
return &ParamName{
|
||||||
|
Redirect: "redirect",
|
||||||
|
Ticket: "ticket",
|
||||||
|
Back: "back",
|
||||||
|
Mode: "mode",
|
||||||
|
LoginId: "loginId",
|
||||||
|
Client: "client",
|
||||||
|
SsoLogoutCall: "ssoLogoutCall",
|
||||||
|
Name: "name",
|
||||||
|
Pwd: "pwd",
|
||||||
|
TimeStamp: "timestamp",
|
||||||
|
Nonce: "nonce",
|
||||||
|
Sign: "sign",
|
||||||
|
SecretKet: "key",
|
||||||
|
}
|
||||||
|
}
|
108
sso/sso.go
Normal file
108
sso/sso.go
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
package sso
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
tokenGo "github.com/weloe/token-go"
|
||||||
|
"github.com/weloe/token-go/config"
|
||||||
|
"github.com/weloe/token-go/constant"
|
||||||
|
"github.com/weloe/token-go/ctx"
|
||||||
|
"github.com/weloe/token-go/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Options construct options
|
||||||
|
type Options struct {
|
||||||
|
SsoOptions *config.SsoOptions
|
||||||
|
SignOptions *config.SignOptions
|
||||||
|
Enforcer tokenGo.IEnforcer
|
||||||
|
}
|
||||||
|
|
||||||
|
type SsoEnforcer struct {
|
||||||
|
apiName *ApiName
|
||||||
|
paramName *ParamName
|
||||||
|
config *config.SsoConfig
|
||||||
|
signConfig *config.SignConfig
|
||||||
|
enforcer tokenGo.IEnforcer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSsoEnforcer create sso enforcer.
|
||||||
|
// If the available field in the parameter is empty, use the default value,
|
||||||
|
// if the required field is empty, return nil and error.
|
||||||
|
func NewSsoEnforcer(options *Options) (*SsoEnforcer, error) {
|
||||||
|
if options.Enforcer == nil {
|
||||||
|
return nil, errors.New("Options.Enforcer can not be nil")
|
||||||
|
}
|
||||||
|
if options.SsoOptions.CookieDomain != "" {
|
||||||
|
options.Enforcer.GetTokenConfig().CookieConfig.Domain = options.SsoOptions.CookieDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
ssoConfig, err := config.NewSsoConfig(options.SsoOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
signConfig, err := config.NewSignConfig(options.SignOptions)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SsoEnforcer{
|
||||||
|
apiName: DefaultApiName(),
|
||||||
|
paramName: DefaultParamName(),
|
||||||
|
config: ssoConfig,
|
||||||
|
signConfig: signConfig,
|
||||||
|
enforcer: options.Enforcer,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SsoEnforcer) SetApi(apiName *ApiName) {
|
||||||
|
s.apiName = apiName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SsoEnforcer) GetApi() *ApiName {
|
||||||
|
return s.apiName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SsoEnforcer) SetParamName(paramName *ParamName) {
|
||||||
|
s.paramName = paramName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SsoEnforcer) GetParamName() *ParamName {
|
||||||
|
return s.paramName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SsoEnforcer) SetSsoConfig(config *config.SsoConfig) {
|
||||||
|
s.config = config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SsoEnforcer) GetSsoConfig() *config.SsoConfig {
|
||||||
|
return s.config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SsoEnforcer) SetSignConfig(config *config.SignConfig) {
|
||||||
|
s.signConfig = config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SsoEnforcer) GetSignConfig() *config.SignConfig {
|
||||||
|
return s.signConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// ssoLogoutBack single-logout callback, redirect to ParamName.Back url.
|
||||||
|
// If http request has back param and value is SELF, redirect to previous page,
|
||||||
|
// if http request back param and value is validated url, redirect to url,
|
||||||
|
// else return model.Ok() directly.
|
||||||
|
func (s *SsoEnforcer) ssoLogoutBack(ctx ctx.Context) (interface{}, error) {
|
||||||
|
paramName := s.paramName
|
||||||
|
request := ctx.Request()
|
||||||
|
response := ctx.Response()
|
||||||
|
back := request.Query(paramName.Back)
|
||||||
|
if back != "" {
|
||||||
|
if back == constant.SELF {
|
||||||
|
return "<script>if(document.referrer != location.href){ location.replace(document.referrer || '/'); }</script>", nil
|
||||||
|
}
|
||||||
|
response.Redirect(back)
|
||||||
|
return nil, nil
|
||||||
|
} else {
|
||||||
|
// back is nil
|
||||||
|
return model.Ok().SetMsg("back is nil, not redirect"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
206
sso/sso_client_api.go
Normal file
206
sso/sso_client_api.go
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
package sso
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/weloe/token-go/ctx"
|
||||||
|
"github.com/weloe/token-go/model"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
=========processor SSO-Client api
|
||||||
|
*/
|
||||||
|
|
||||||
|
// SsoClientLogin SSO-Client login.
|
||||||
|
func (s *SsoEnforcer) SsoClientLogin(ctx ctx.Context) (interface{}, error) {
|
||||||
|
request := ctx.Request()
|
||||||
|
response := ctx.Response()
|
||||||
|
paramName := s.paramName
|
||||||
|
apiName := s.apiName
|
||||||
|
|
||||||
|
// get back value and redirect
|
||||||
|
back := request.Query(paramName.Back)
|
||||||
|
if back == "" {
|
||||||
|
back = "/"
|
||||||
|
}
|
||||||
|
isLogin, err := s.enforcer.IsLogin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// if the current client is already logged in, there is no need to redirect to the SSO-Server
|
||||||
|
if isLogin {
|
||||||
|
response.Redirect(back)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// if isLogin == false, attempt to get ticket
|
||||||
|
ticket := request.Query(paramName.Ticket)
|
||||||
|
|
||||||
|
// if ticket == "", redirect to SSO-Server, the default path is /sso/auth
|
||||||
|
if ticket == "" {
|
||||||
|
serverAuthUrl, err := s.buildServerAuthUrl(request.UrlNoQuery(), back)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
response.Redirect(serverAuthUrl)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// else ticket != "", need to log in by ticket
|
||||||
|
var loginId string
|
||||||
|
|
||||||
|
// get current path
|
||||||
|
ssoLoginUrl := apiName.SsoLogin
|
||||||
|
if s.config.IsHttp {
|
||||||
|
// use http request to SSO-Server to check ticket in SSO-Server
|
||||||
|
var ssoLogoutCall string
|
||||||
|
if s.config.IsSlo {
|
||||||
|
// get logout callback url
|
||||||
|
if s.config.SsoLogoutCall != "" {
|
||||||
|
ssoLogoutCall = s.config.SsoLogoutCall
|
||||||
|
} else if ssoLoginUrl != "" {
|
||||||
|
ssoLogoutCall = strings.ReplaceAll(request.UrlNoQuery(), ssoLoginUrl, apiName.SsoLogoutCall)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// send http to check
|
||||||
|
checkTicketUrl, err := s.buildCheckTicketUrl(ticket, ssoLogoutCall)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// send http request
|
||||||
|
resp, err := s.request(checkTicketUrl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp.Code == model.ERROR {
|
||||||
|
return nil, errors.New("request failed: " + resp.Msg)
|
||||||
|
}
|
||||||
|
loginId = resp.Data.(string)
|
||||||
|
} else {
|
||||||
|
// use adapter to check ticket
|
||||||
|
loginId, err = s.CheckTicket(ticket)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if set callback TicketResultHandle
|
||||||
|
if s.config.TicketResultHandle != nil {
|
||||||
|
return s.config.TicketResultHandle(loginId, back)
|
||||||
|
}
|
||||||
|
|
||||||
|
// if loginId == "", return error
|
||||||
|
if loginId == "" {
|
||||||
|
return nil, errors.New("invalid ticket: " + ticket)
|
||||||
|
}
|
||||||
|
// login
|
||||||
|
_, err = s.enforcer.Login(loginId, ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// redirect to back
|
||||||
|
response.Redirect(back)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SsoClientLogout SSO-Client single-logout.
|
||||||
|
func (s *SsoEnforcer) SsoClientLogout(ctx ctx.Context) (interface{}, error) {
|
||||||
|
ssoConfig := s.config
|
||||||
|
|
||||||
|
// enable single-logout and isHttp == false
|
||||||
|
if ssoConfig.IsSlo && !ssoConfig.IsHttp {
|
||||||
|
isLogin, err := s.enforcer.IsLogin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// check if you are logged in
|
||||||
|
if isLogin {
|
||||||
|
var id string
|
||||||
|
id, err = s.enforcer.GetLoginId(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// client logout
|
||||||
|
err = s.enforcer.LogoutById(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// callback
|
||||||
|
return s.ssoLogoutBack(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// enable single-logout and isHttp
|
||||||
|
if ssoConfig.IsSlo && ssoConfig.IsHttp {
|
||||||
|
isLogin, err := s.enforcer.IsLogin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !isLogin {
|
||||||
|
return s.ssoLogoutBack(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// get id
|
||||||
|
var id string
|
||||||
|
id, err = s.enforcer.GetLoginId(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// use SSO-Server single-logout api to logout
|
||||||
|
sloUrl, err := s.buildSloUrl(id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
res, err := s.request(sloUrl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// check response data
|
||||||
|
if res.Code == model.SUCCESS {
|
||||||
|
login, _ := s.enforcer.IsLogin(ctx)
|
||||||
|
if login {
|
||||||
|
err := s.enforcer.Logout(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.ssoLogoutBack(ctx)
|
||||||
|
} else {
|
||||||
|
return nil, errors.New("request failed: " + res.Msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("not handle")
|
||||||
|
}
|
||||||
|
|
||||||
|
// SsoClientLogoutCall client logout callback.
|
||||||
|
func (s *SsoEnforcer) SsoClientLogoutCall(ctx ctx.Context) (interface{}, error) {
|
||||||
|
request := ctx.Request()
|
||||||
|
loginId := request.Query(s.paramName.LoginId)
|
||||||
|
if loginId == "" {
|
||||||
|
return nil, errors.New("request param must include loginId")
|
||||||
|
}
|
||||||
|
// check request param
|
||||||
|
err := s.checkRequest(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// logout
|
||||||
|
err = s.enforcer.Logout(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return model.Ok().SetMsg("logout callback success"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetData client build url to sent http to get data from SSO-Server.
|
||||||
|
func (s *SsoEnforcer) GetData(paramMap map[string]string) (interface{}, error) {
|
||||||
|
finalUrl, err := s.buildGetDataUrl(paramMap)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
res, err := s.config.SendHttp(finalUrl)
|
||||||
|
return res, err
|
||||||
|
}
|
64
sso/sso_dispatcher_api.go
Normal file
64
sso/sso_dispatcher_api.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package sso
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/weloe/token-go/ctx"
|
||||||
|
"github.com/weloe/token-go/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
=========dispatcher api
|
||||||
|
*/
|
||||||
|
|
||||||
|
// ServerDisPatcher dispatcher SSO-Server api, returns model.Result or string.
|
||||||
|
func (s *SsoEnforcer) ServerDisPatcher(ctx ctx.Context) interface{} {
|
||||||
|
request := ctx.Request()
|
||||||
|
apiName := s.apiName
|
||||||
|
path := request.Path()
|
||||||
|
var res interface{}
|
||||||
|
var err error
|
||||||
|
if path == apiName.SsoAuth {
|
||||||
|
res, err = s.SsoAuth(ctx)
|
||||||
|
} else if path == apiName.SsoDoLogin {
|
||||||
|
res, err = s.SsoDoLogin(ctx)
|
||||||
|
} else if path == apiName.SsoCheckTicket && s.config.IsHttp {
|
||||||
|
res, err = s.SsoCheckTicket(ctx)
|
||||||
|
} else if path == apiName.SsoSignout {
|
||||||
|
res, err = s.SsoSignOut(ctx)
|
||||||
|
} else {
|
||||||
|
return model.Error().SetMsg("not handle")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return model.Error().SetMsg(err.Error())
|
||||||
|
}
|
||||||
|
if res == nil {
|
||||||
|
return model.Ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientDispatcher dispatcher Client api, returns model.Result or string.
|
||||||
|
func (s *SsoEnforcer) ClientDispatcher(ctx ctx.Context) interface{} {
|
||||||
|
request := ctx.Request()
|
||||||
|
apiName := s.apiName
|
||||||
|
path := request.Path()
|
||||||
|
var res interface{}
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if path == apiName.SsoLogin {
|
||||||
|
res, err = s.SsoClientLogin(ctx)
|
||||||
|
} else if path == apiName.SsoLogout {
|
||||||
|
res, err = s.SsoClientLogout(ctx)
|
||||||
|
} else if path == apiName.SsoLogoutCall && s.config.IsSlo && s.config.IsHttp {
|
||||||
|
res, err = s.SsoClientLogoutCall(ctx)
|
||||||
|
} else {
|
||||||
|
return model.Error().SetMsg("not handle")
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return model.Error().SetMsg(err.Error())
|
||||||
|
}
|
||||||
|
if res == nil {
|
||||||
|
return model.Ok()
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
494
sso/sso_internal_api.go
Normal file
494
sso/sso_internal_api.go
Normal file
@@ -0,0 +1,494 @@
|
|||||||
|
package sso
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/weloe/token-go/constant"
|
||||||
|
"github.com/weloe/token-go/ctx"
|
||||||
|
"github.com/weloe/token-go/model"
|
||||||
|
"github.com/weloe/token-go/util"
|
||||||
|
"log"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
=========internal api
|
||||||
|
*/
|
||||||
|
|
||||||
|
// checkRequest check request param timestamp,nonce,sign.
|
||||||
|
func (s *SsoEnforcer) checkRequest(request ctx.Request) error {
|
||||||
|
timestamp := request.Query(s.paramName.TimeStamp)
|
||||||
|
nonce := request.Query(s.paramName.Nonce)
|
||||||
|
sign := request.Query(s.paramName.Sign)
|
||||||
|
err := s.checkTimeStamp(timestamp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if s.signConfig.IsCheckNonce {
|
||||||
|
err = s.checkNonce(nonce)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = s.checkSign(timestamp, nonce, sign)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTicket create ticket by account-id.
|
||||||
|
func (s *SsoEnforcer) CreateTicket(loginId string, client string) (string, error) {
|
||||||
|
// create random string ticket
|
||||||
|
ticket, err := util.GenerateRandomString64()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// save ticket-id+client
|
||||||
|
err = s.saveTicket(ticket, loginId, client)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// save id-ticket
|
||||||
|
err = s.saveTicketIndex(ticket, loginId)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ticket, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLoginId get loginId by ticket.
|
||||||
|
func (s *SsoEnforcer) GetLoginId(ticket string) string {
|
||||||
|
if ticket == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
loginId := s.enforcer.GetAdapter().GetStr(s.spliceTicketSaveKey(ticket))
|
||||||
|
if loginId != "" && strings.Contains(loginId, ",") {
|
||||||
|
split := strings.Split(loginId, ",")
|
||||||
|
loginId = split[0]
|
||||||
|
}
|
||||||
|
return loginId
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTicket get ticket by loginId.
|
||||||
|
func (s *SsoEnforcer) GetTicket(loginId string) string {
|
||||||
|
if loginId == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return s.enforcer.GetAdapter().GetStr(s.spliceTicketIndexKey(loginId))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckTicket use config.Client to check ticket,return loginId.
|
||||||
|
func (s *SsoEnforcer) CheckTicket(ticket string) (string, error) {
|
||||||
|
return s.CheckTicketByClient(ticket, s.config.Client)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckTicketByClient check ticket by pointing client,return loginId.
|
||||||
|
func (s *SsoEnforcer) CheckTicketByClient(ticket string, client string) (string, error) {
|
||||||
|
id := s.enforcer.GetAdapter().GetStr(s.spliceTicketSaveKey(ticket))
|
||||||
|
if id == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var ticketClient string
|
||||||
|
if strings.Contains(id, ",") {
|
||||||
|
split := strings.Split(id, ",")
|
||||||
|
id = split[0]
|
||||||
|
ticketClient = split[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if client != "" && client != ticketClient {
|
||||||
|
return "", fmt.Errorf("the ticket does not belong to the client, client: %v, ticket: %v", client, ticket)
|
||||||
|
}
|
||||||
|
err := s.deleteTicket(ticket)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
err = s.deleteTicketIndex(id)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckRedirectUrl check redirectUrl.
|
||||||
|
func (s *SsoEnforcer) CheckRedirectUrl(url string) error {
|
||||||
|
if !util.IsValidUrl(url) {
|
||||||
|
return fmt.Errorf("invalid redirect url: %v", url)
|
||||||
|
}
|
||||||
|
index := strings.Index(url, "?")
|
||||||
|
if index != -1 {
|
||||||
|
url = url[0:index]
|
||||||
|
}
|
||||||
|
allowUrls := strings.Split(s.GetAllowUrl(), ",")
|
||||||
|
|
||||||
|
if !util.HasUrl(allowUrls, url) {
|
||||||
|
return fmt.Errorf("illegal redirect url: %v", url)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterSloCallbackUrl register the URL of the single logout callback for the account id.
|
||||||
|
func (s *SsoEnforcer) RegisterSloCallbackUrl(loginId string, sloCallbackUrl string) error {
|
||||||
|
if loginId == "" || sloCallbackUrl == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
session := s.enforcer.GetSession(loginId)
|
||||||
|
// splice session id
|
||||||
|
sessionId := s.enforcer.GetTokenConfig().TokenName + ":" + s.enforcer.GetType() + ":session:" + loginId
|
||||||
|
if session == nil {
|
||||||
|
session = model.NewSession(sessionId, "account-session", loginId)
|
||||||
|
}
|
||||||
|
value := session.Get(constant.SLO_CALLBACK_SET_KEY)
|
||||||
|
|
||||||
|
var v []string
|
||||||
|
if value != nil {
|
||||||
|
sv, ok := value.([]string)
|
||||||
|
v = sv
|
||||||
|
if !ok {
|
||||||
|
return errors.New("session SLO_CALLBACK_SET_KEY_ data convert into []string failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
v = util.AppendStr(v, sloCallbackUrl)
|
||||||
|
|
||||||
|
session.Set(sessionId, v)
|
||||||
|
// update session
|
||||||
|
err := s.enforcer.UpdateSession(loginId, session)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ssoSignOutById single sign-out of the specified account.
|
||||||
|
// Use loginId to get single sign-out urls from session.
|
||||||
|
func (s *SsoEnforcer) ssoSignOutById(loginId string) error {
|
||||||
|
// if loginId is not logged, return error
|
||||||
|
session := s.enforcer.GetSession(loginId)
|
||||||
|
if session == nil {
|
||||||
|
return errors.New("this loginId is not logged in")
|
||||||
|
}
|
||||||
|
value := session.Get(constant.SLO_CALLBACK_SET_KEY)
|
||||||
|
if value == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
urls, ok := value.([]string)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("convert into []string failed")
|
||||||
|
}
|
||||||
|
// range urls to make client logout
|
||||||
|
for _, url := range urls {
|
||||||
|
// join url
|
||||||
|
newUrl, err := s.joinLoginIdAndSign(url, loginId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// sent http
|
||||||
|
_, err = s.config.SendHttp(newUrl)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// server logout
|
||||||
|
return s.enforcer.LogoutById(loginId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// joinLoginIdAndSign splice the loginId to the url, and stitching parameters such as sign.
|
||||||
|
func (s *SsoEnforcer) joinLoginIdAndSign(url string, id string) (string, error) {
|
||||||
|
nonce, err := util.GenerateRandomString32()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
|
||||||
|
sign, err := s.createSign(timestamp, nonce)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
str := url + "?" + s.paramName.LoginId + "=" + id + "&" + s.paramName.TimeStamp + "=" + timestamp + "&" + s.paramName.Nonce + "=" + nonce + "&" + s.paramName.Sign + "=" + sign
|
||||||
|
return str, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildServerAuthUrl SSO-Client build SSO-Server single sign-on url.
|
||||||
|
func (s *SsoEnforcer) buildServerAuthUrl(clientLoginUrl string, back string) (string, error) {
|
||||||
|
if clientLoginUrl == "" {
|
||||||
|
return "", errors.New("arg[0] clientLoginUrl can not be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// get server auth url
|
||||||
|
authUrl := s.config.SpliceAuthUrl()
|
||||||
|
|
||||||
|
client := s.config.Client
|
||||||
|
|
||||||
|
if client != "" {
|
||||||
|
authUrl = util.AddQuery(authUrl, s.paramName.Client, client)
|
||||||
|
}
|
||||||
|
|
||||||
|
// splice back url
|
||||||
|
if back != "" {
|
||||||
|
back = util.Encode(back)
|
||||||
|
clientLoginUrl = util.AddQuery(clientLoginUrl, s.paramName.Back, back)
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.AddQuery(authUrl, s.paramName.Redirect, clientLoginUrl), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildRedirectUrl the server gives the redirectUrl of the ticket to the client.
|
||||||
|
// Check redirect url, delete old ticket of loginId and create new ticket, then return url with new ticket.
|
||||||
|
func (s *SsoEnforcer) buildRedirectUrl(loginId string, client string, redirect string) (string, error) {
|
||||||
|
// check redirect url
|
||||||
|
err := s.CheckRedirectUrl(redirect)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// delete old ticket
|
||||||
|
err = s.deleteTicket(s.GetTicket(loginId))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
// create new ticket
|
||||||
|
ticket, err := s.CreateTicket(loginId, client)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// return redirect + "?" + s.paramName.Ticket + "=" + ticket, nil
|
||||||
|
return util.AddQuery(s.encodeBackParam(redirect), s.paramName.Ticket, ticket), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeBackParam find back param from url, and encode back param.
|
||||||
|
func (s *SsoEnforcer) encodeBackParam(url string) string {
|
||||||
|
// get back location
|
||||||
|
index := strings.Index(url, "?"+s.paramName.Back+"=")
|
||||||
|
if index == -1 {
|
||||||
|
index = strings.Index(url, "&"+s.paramName.Back+"=")
|
||||||
|
if index == -1 {
|
||||||
|
return url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// encode
|
||||||
|
length := len(s.paramName.Back) + 2
|
||||||
|
back := url[index+length:]
|
||||||
|
back = util.Encode(back)
|
||||||
|
|
||||||
|
// update back
|
||||||
|
url = url[:index+length] + back
|
||||||
|
return url
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildCheckTicketUrl build to check ticket.
|
||||||
|
func (s *SsoEnforcer) buildCheckTicketUrl(ticket string, ssoLogoutCallUrl string) (string, error) {
|
||||||
|
if ticket == "" {
|
||||||
|
return "", errors.New("buildCheckTicketUrl() ticket can not be nil")
|
||||||
|
}
|
||||||
|
checkTicketUrl := s.config.SpliceCheckTicketUrl()
|
||||||
|
client := s.config.Client
|
||||||
|
paramMap := make(map[string]string)
|
||||||
|
if client != "" {
|
||||||
|
paramMap[s.paramName.Client] = client
|
||||||
|
}
|
||||||
|
paramMap[s.paramName.Ticket] = ticket
|
||||||
|
if ssoLogoutCallUrl != "" {
|
||||||
|
paramMap[s.paramName.SsoLogoutCall] = ssoLogoutCallUrl
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.AddQueryMap(checkTicketUrl, paramMap), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildSloUrl build single-logout url.
|
||||||
|
func (s *SsoEnforcer) buildSloUrl(loginId string) (string, error) {
|
||||||
|
sloUrl := s.config.SpliceSloUrl()
|
||||||
|
url, err := s.joinLoginIdAndSign(sloUrl, loginId)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return url, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildGetDataUrl build getData url with sign,timestamp,nonce.
|
||||||
|
func (s *SsoEnforcer) buildGetDataUrl(paramMap map[string]string) (string, error) {
|
||||||
|
getDataUrl := s.config.SpliceGetDataUrl()
|
||||||
|
return s.buildCustomPathUrl(getDataUrl, paramMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildCustomPathUrl add paramMap to path.
|
||||||
|
func (s *SsoEnforcer) buildCustomPathUrl(path string, paramMap map[string]string) (string, error) {
|
||||||
|
u := path
|
||||||
|
|
||||||
|
if !strings.HasPrefix(u, "http") {
|
||||||
|
serverUrl := s.config.ServerUrl
|
||||||
|
if serverUrl == "" {
|
||||||
|
return "", errors.New("please set sso serverUrl")
|
||||||
|
}
|
||||||
|
u = util.SpliceUrl(serverUrl, path)
|
||||||
|
}
|
||||||
|
// sign map
|
||||||
|
timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
|
||||||
|
paramMap[s.paramName.TimeStamp] = timestamp
|
||||||
|
nonce, err := util.GenerateRandomString32()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
paramMap[s.paramName.Nonce] = nonce
|
||||||
|
// create sign
|
||||||
|
sign, err := s.createSign(timestamp, nonce)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
paramMap[s.paramName.Sign] = sign
|
||||||
|
finalUrl := util.AddQueryMap(u, paramMap)
|
||||||
|
return finalUrl, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// request send http request and use json.Unmarshal to converted to *model.Result.
|
||||||
|
func (s *SsoEnforcer) request(url string) (*model.Result, error) {
|
||||||
|
resp, err := s.config.SendHttp(url)
|
||||||
|
|
||||||
|
log.Printf("http request response: %s", resp)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result := &model.Result{}
|
||||||
|
err = json.Unmarshal([]byte(resp), result)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SsoEnforcer) GetAllowUrl() string {
|
||||||
|
return s.config.AllowUrl
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveTicket save ticket-id+client.
|
||||||
|
func (s *SsoEnforcer) saveTicket(ticket string, loginId string, client string) error {
|
||||||
|
value := loginId
|
||||||
|
if client != "" {
|
||||||
|
value += "," + client
|
||||||
|
}
|
||||||
|
ticketTimeout := s.config.TicketTimeout
|
||||||
|
return s.enforcer.GetAdapter().SetStr(s.spliceTicketSaveKey(ticket), value, ticketTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveTicketIndex save id-ticket.
|
||||||
|
func (s *SsoEnforcer) saveTicketIndex(ticket string, id string) error {
|
||||||
|
ticketTimeout := s.config.TicketTimeout
|
||||||
|
return s.enforcer.GetAdapter().SetStr(s.spliceTicketIndexKey(id), ticket, ticketTimeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// spliceTicketSaveKey splice ticket-id key.
|
||||||
|
func (s *SsoEnforcer) spliceTicketSaveKey(ticket string) string {
|
||||||
|
return s.enforcer.GetTokenConfig().TokenName + ":ticket:" + ticket
|
||||||
|
}
|
||||||
|
|
||||||
|
// spliceTicketIndexKey splice id-ticket key.
|
||||||
|
func (s *SsoEnforcer) spliceTicketIndexKey(id string) string {
|
||||||
|
return s.enforcer.GetTokenConfig().TokenName + ":id-ticket:" + id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SsoEnforcer) deleteTicket(ticket string) error {
|
||||||
|
if ticket == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.enforcer.GetAdapter().DeleteStr(s.spliceTicketSaveKey(ticket))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SsoEnforcer) deleteTicketIndex(id string) error {
|
||||||
|
if id == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.enforcer.GetAdapter().DeleteStr(s.spliceTicketIndexKey(id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkTimeStamp determine whether the gap between the timestamp and the current timestamp is within the allowable range.
|
||||||
|
func (s *SsoEnforcer) checkTimeStamp(timestamp string) error {
|
||||||
|
parseInt, err := strconv.ParseInt(timestamp, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !s.isValidTimeStamp(parseInt) {
|
||||||
|
return errors.New("timestamp is out of allowed range: " + timestamp)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isValidTimeStamp determine whether the gap between the timestamp and the current timestamp is within the allowable range.
|
||||||
|
func (s *SsoEnforcer) isValidTimeStamp(timestamp int64) bool {
|
||||||
|
allowDisparity := s.signConfig.TimeStampDisparity
|
||||||
|
nowDisparity := time.Now().UnixMilli() - timestamp
|
||||||
|
|
||||||
|
return allowDisparity == 1 || nowDisparity <= allowDisparity
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkNonce the same nonce can only be verified once.
|
||||||
|
func (s *SsoEnforcer) checkNonce(nonce string) error {
|
||||||
|
if nonce == "" {
|
||||||
|
return errors.New("nonce is nil")
|
||||||
|
}
|
||||||
|
// if nonce exists in adapter
|
||||||
|
if !s.isValidNonce(nonce) {
|
||||||
|
return errors.New("the nonce has been used: " + nonce)
|
||||||
|
}
|
||||||
|
// set nonce after nonce
|
||||||
|
err := s.enforcer.GetAdapter().SetStr(s.spliceNonceSaveKey(nonce), nonce, s.signConfig.GetSaveNonceExpire()*2+2)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isValidNonce determine random string, if not exist in adapter, return true.
|
||||||
|
func (s *SsoEnforcer) isValidNonce(nonce string) bool {
|
||||||
|
if nonce == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
key := s.spliceNonceSaveKey(nonce)
|
||||||
|
// if not exist in adapter, return true.
|
||||||
|
return s.enforcer.GetAdapter().GetStr(key) == ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SsoEnforcer) checkSign(timestamp string, nonce string, sign string) error {
|
||||||
|
valid, err := s.isValidSign(timestamp, nonce, sign)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
return errors.New("invalid sign: " + sign)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isValidSign use timestamp,nonce,sign to createSign to compare.
|
||||||
|
func (s *SsoEnforcer) isValidSign(timestamp string, nonce string, sign string) (bool, error) {
|
||||||
|
recreateSign, err := s.createSign(timestamp, nonce)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return recreateSign == sign, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createSign use util.MD5() to generate str.
|
||||||
|
func (s *SsoEnforcer) createSign(timestamp string, nonce string) (string, error) {
|
||||||
|
secretKey := s.signConfig.SecretKey
|
||||||
|
if secretKey == "" {
|
||||||
|
return "", errors.New("please check SignConfig.SecretKey, SecretKey can not be nil")
|
||||||
|
}
|
||||||
|
str := s.paramName.Nonce + "=" + nonce + "&" + s.paramName.TimeStamp + "=" + timestamp + "&" + s.paramName.SecretKet + "=" + secretKey
|
||||||
|
|
||||||
|
return util.MD5(str), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// spliceNonceSaveKey splice nonce store key.
|
||||||
|
func (s *SsoEnforcer) spliceNonceSaveKey(nonce string) string {
|
||||||
|
return s.enforcer.GetTokenConfig().TokenName + ":sign:nonce:" + nonce
|
||||||
|
}
|
135
sso/sso_server_api.go
Normal file
135
sso/sso_server_api.go
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
package sso
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/weloe/token-go/constant"
|
||||||
|
"github.com/weloe/token-go/ctx"
|
||||||
|
"github.com/weloe/token-go/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
=========processor SSO-Server api
|
||||||
|
*/
|
||||||
|
|
||||||
|
// SsoAuth SSO-Server: auth.
|
||||||
|
func (s *SsoEnforcer) SsoAuth(ctx ctx.Context) (interface{}, error) {
|
||||||
|
request := ctx.Request()
|
||||||
|
response := ctx.Response()
|
||||||
|
|
||||||
|
isLogin, err := s.enforcer.IsLogin(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// if you have not logged in to the SSO-Server, need to log in first
|
||||||
|
if !isLogin {
|
||||||
|
return s.config.NotLoginView(), nil
|
||||||
|
}
|
||||||
|
// if you have logged, check the mode
|
||||||
|
mode := request.Query(s.paramName.Mode)
|
||||||
|
|
||||||
|
redirect := request.Query(s.paramName.Redirect)
|
||||||
|
|
||||||
|
// if mode == simple, redirect to client directly
|
||||||
|
if mode == constant.MODE_SIMPLE {
|
||||||
|
err = s.CheckRedirectUrl(redirect)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
response.Redirect(redirect)
|
||||||
|
return nil, nil
|
||||||
|
} else {
|
||||||
|
// mode = ticket, redirect to client login with new ticket
|
||||||
|
id, err := s.enforcer.GetLoginId(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
redirectUrl, err := s.buildRedirectUrl(id, request.Query(s.paramName.Client), redirect)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
response.Redirect(redirectUrl)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SsoDoLogin SSO-Server: rest login api.
|
||||||
|
func (s *SsoEnforcer) SsoDoLogin(ctx ctx.Context) (interface{}, error) {
|
||||||
|
request := ctx.Request()
|
||||||
|
paramName := s.paramName
|
||||||
|
if s.config.DoLoginHandle == nil {
|
||||||
|
return nil, errors.New("SsoConfig.DoLoginHandle is nil")
|
||||||
|
}
|
||||||
|
resp, err := s.config.DoLoginHandle(request.Query(paramName.Name), request.Query(paramName.Pwd), ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SsoCheckTicket SSO-Server: check ticket to get loginId, returns loginId
|
||||||
|
func (s *SsoEnforcer) SsoCheckTicket(ctx ctx.Context) (interface{}, error) {
|
||||||
|
paramName := s.paramName
|
||||||
|
request := ctx.Request()
|
||||||
|
client := request.Query(paramName.Client)
|
||||||
|
ticket := request.Query(paramName.Ticket)
|
||||||
|
if ticket == "" {
|
||||||
|
return nil, errors.New("ticket can not be nil")
|
||||||
|
}
|
||||||
|
sloCallback := request.Query(paramName.SsoLogoutCall)
|
||||||
|
|
||||||
|
// check ticket
|
||||||
|
loginId, err := s.CheckTicketByClient(ticket, client)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// register single sign out callback url
|
||||||
|
err = s.RegisterSloCallbackUrl(loginId, sloCallback)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if loginId == "" {
|
||||||
|
return nil, errors.New("invalid ticket: " + ticket)
|
||||||
|
}
|
||||||
|
return model.Ok().SetData(loginId), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SsoSignOut SSO-Server: single sign-out.
|
||||||
|
func (s *SsoEnforcer) SsoSignOut(ctx ctx.Context) (interface{}, error) {
|
||||||
|
request := ctx.Request()
|
||||||
|
paramName := s.paramName
|
||||||
|
|
||||||
|
// if enable single sign-out and request param has loginId
|
||||||
|
reqLoginId := request.Query(paramName.LoginId)
|
||||||
|
if s.config.IsSlo && reqLoginId == "" {
|
||||||
|
loginId, err := s.enforcer.GetLoginId(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if loginId != "" {
|
||||||
|
err = s.ssoSignOutById(loginId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// callback
|
||||||
|
return s.ssoLogoutBack(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if enable http,single sign-out and request param has loginId
|
||||||
|
if s.config.IsHttp && s.config.IsSlo && reqLoginId != "" {
|
||||||
|
err := s.checkRequest(request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Use loginId to get single sign-out urls from session, traverse the urls to send the request to notify client
|
||||||
|
err = s.ssoSignOutById(reqLoginId)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return model.Ok().SetMsg("sso sign-out success"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("not handle")
|
||||||
|
}
|
103
sso/sso_test.go
Normal file
103
sso/sso_test.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package sso
|
||||||
|
|
||||||
|
import (
|
||||||
|
tokenGo "github.com/weloe/token-go"
|
||||||
|
"github.com/weloe/token-go/config"
|
||||||
|
"github.com/weloe/token-go/ctx"
|
||||||
|
"github.com/weloe/token-go/model"
|
||||||
|
"github.com/weloe/token-go/util"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SendGetRequest(url string) (string, error) {
|
||||||
|
return util.SendGetRequest(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSsoServerEnforcer(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
// use default adapter
|
||||||
|
adapter := tokenGo.NewDefaultAdapter()
|
||||||
|
enforcer, err := tokenGo.NewEnforcer(adapter)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("NewEnforcer() failed: %v", err)
|
||||||
|
}
|
||||||
|
// enable logger
|
||||||
|
enforcer.EnableLog()
|
||||||
|
ssoOptions := &config.SsoOptions{
|
||||||
|
Mode: "",
|
||||||
|
TicketTimeout: 300,
|
||||||
|
AllowUrl: "*",
|
||||||
|
IsSlo: true,
|
||||||
|
|
||||||
|
IsHttp: true,
|
||||||
|
ServerUrl: "http://token-go-sso-server.com:9000",
|
||||||
|
NotLoginView: func() interface{} {
|
||||||
|
msg := "not log in SSO-Server, please visit <a href='/sso/doLogin?name=tokengo&pwd=123456' target='_blank'> doLogin </a>"
|
||||||
|
return msg
|
||||||
|
},
|
||||||
|
DoLoginHandle: func(name string, pwd string, ctx ctx.Context) (interface{}, error) {
|
||||||
|
if name != "tokengo" {
|
||||||
|
return "name error", nil
|
||||||
|
}
|
||||||
|
if pwd != "123456" {
|
||||||
|
return "pwd error", nil
|
||||||
|
}
|
||||||
|
token, err := enforcer.Login("1001", ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return model.Ok().SetData(token), nil
|
||||||
|
},
|
||||||
|
SendHttp: func(url string) (string, error) {
|
||||||
|
return SendGetRequest(url)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
signOptions := &config.SignOptions{
|
||||||
|
SecretKey: "kQwIOrYvnXmSDkwEiFngrKidMcdrgKor",
|
||||||
|
IsCheckNonce: true,
|
||||||
|
}
|
||||||
|
ssoEnforcer, err := NewSsoEnforcer(&Options{
|
||||||
|
SsoOptions: ssoOptions,
|
||||||
|
SignOptions: signOptions,
|
||||||
|
Enforcer: enforcer,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("NewSsoEnforcer() failed: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf("enforcer: %v", ssoEnforcer)
|
||||||
|
}
|
||||||
|
func TestNewSsoClient3Enforcer(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
// use default adapter
|
||||||
|
adapter := tokenGo.NewDefaultAdapter()
|
||||||
|
enforcer, err := tokenGo.NewEnforcer(adapter)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("NewEnforcer() failed: %v", err)
|
||||||
|
}
|
||||||
|
// enable logger
|
||||||
|
enforcer.EnableLog()
|
||||||
|
ssoOptions := &config.SsoOptions{
|
||||||
|
AuthUrl: "/sso/auth",
|
||||||
|
IsSlo: true,
|
||||||
|
IsHttp: true,
|
||||||
|
SloUrl: "/sso/signout",
|
||||||
|
CheckTicketUrl: "/sso/checkTicket",
|
||||||
|
ServerUrl: "http://token-go-sso-server.com:9000",
|
||||||
|
SendHttp: func(url string) (string, error) {
|
||||||
|
return SendGetRequest(url)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
signOptions := &config.SignOptions{
|
||||||
|
SecretKey: "kQwIOrYvnXmSDkwEiFngrKidMcdrgKor",
|
||||||
|
IsCheckNonce: true,
|
||||||
|
}
|
||||||
|
ssoEnforcer, err := NewSsoEnforcer(&Options{
|
||||||
|
SsoOptions: ssoOptions,
|
||||||
|
SignOptions: signOptions,
|
||||||
|
Enforcer: enforcer,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("NewSsoEnforcer() failed: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf("enforcer: %v", ssoEnforcer)
|
||||||
|
}
|
13
util/secure.go
Normal file
13
util/secure.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/md5"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func MD5(str string) string {
|
||||||
|
data := []byte(str)
|
||||||
|
has := md5.Sum(data)
|
||||||
|
md5str := fmt.Sprintf("%x", has)
|
||||||
|
return md5str
|
||||||
|
}
|
12
util/sign.go
Normal file
12
util/sign.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IsValidTimeStamp determine whether the gap between the startTime and the current timestamp is within the allowable range.
|
||||||
|
func IsValidTimeStamp(startTime int64, allowDisparity int64) bool {
|
||||||
|
nowDisparity := time.Now().UnixMilli() - startTime
|
||||||
|
|
||||||
|
return allowDisparity == 1 || nowDisparity <= allowDisparity
|
||||||
|
}
|
137
util/url.go
Normal file
137
util/url.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SendGetRequest(url string) (string, error) {
|
||||||
|
response, err := http.Get(url)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("http.Get() failed: %v", err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func(Body io.ReadCloser) {
|
||||||
|
err = Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("read response body failed: %v", err)
|
||||||
|
}
|
||||||
|
}(response.Body)
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("ioutil.ReadAll() failed: %v", err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(body), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SpliceUrl splice two url.
|
||||||
|
// Examples:
|
||||||
|
// u1 = "http://domain.com" u2 = "/sso/auth" return http://domain.com/sso/auth
|
||||||
|
func SpliceUrl(u1 string, u2 string) string {
|
||||||
|
if u1 == "" {
|
||||||
|
return u2
|
||||||
|
}
|
||||||
|
if u2 == "" {
|
||||||
|
return u1
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(u2, "http") {
|
||||||
|
return u2
|
||||||
|
}
|
||||||
|
|
||||||
|
return u1 + u2
|
||||||
|
}
|
||||||
|
|
||||||
|
func HasUrl(urls []string, url string) bool {
|
||||||
|
for _, s := range urls {
|
||||||
|
if MatchUrl(s, url) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func MatchUrl(pattern string, url string) bool {
|
||||||
|
if pattern == "*" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return pattern == url
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsValidUrl(u1 string) bool {
|
||||||
|
|
||||||
|
_, err := url.ParseRequestURI(u1)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(u1)
|
||||||
|
if err != nil || u.Scheme == "" || u.Host == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if the URL has a valid scheme (http or https)
|
||||||
|
if u.Scheme != "http" && u.Scheme != "https" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddQueryMap add map param for the path.
|
||||||
|
func AddQueryMap(path string, paramMap map[string]string) string {
|
||||||
|
queryString := MapToQuery(paramMap)
|
||||||
|
|
||||||
|
return AddQueryValue(path, queryString)
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddQueryValue(path string, queryString string) string {
|
||||||
|
index := strings.LastIndex(path, "?")
|
||||||
|
// if the path is not included "?"
|
||||||
|
if index == -1 {
|
||||||
|
return path + "?" + queryString
|
||||||
|
}
|
||||||
|
// if the last is "?"
|
||||||
|
if index == len(path)-1 {
|
||||||
|
return path + queryString
|
||||||
|
}
|
||||||
|
|
||||||
|
// if "?" inside path, the last is not "&" and queryString's first string is not "&"
|
||||||
|
if index < len(path)-1 {
|
||||||
|
if strings.LastIndex(path, "&") != len(path)-1 && strings.Index(path, "&") != 0 {
|
||||||
|
return path + "&" + queryString
|
||||||
|
} else {
|
||||||
|
return path + queryString
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddQuery add query param for the path.
|
||||||
|
func AddQuery(path string, key string, value string) string {
|
||||||
|
queryString := key + "=" + value
|
||||||
|
return AddQueryValue(path, queryString)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapToQuery convert map to k=v array, and use "&" to join.
|
||||||
|
func MapToQuery(paramMap map[string]string) string {
|
||||||
|
var queryString []string
|
||||||
|
for k, v := range paramMap {
|
||||||
|
queryString = append(queryString, k+"="+v)
|
||||||
|
}
|
||||||
|
query := strings.Join(queryString, "&")
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func Encode(u string) string {
|
||||||
|
return url.QueryEscape(u)
|
||||||
|
}
|
144
util/url_test.go
Normal file
144
util/url_test.go
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSpliceUrl(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
u1 string
|
||||||
|
u2 string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "success1",
|
||||||
|
args: args{
|
||||||
|
u1: "http://domain.com",
|
||||||
|
u2: "/sso/auth",
|
||||||
|
},
|
||||||
|
want: "http://domain.com/sso/auth",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "success2",
|
||||||
|
args: args{
|
||||||
|
u1: "",
|
||||||
|
u2: "http://domain.com/sso/auth",
|
||||||
|
},
|
||||||
|
want: "http://domain.com/sso/auth",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := SpliceUrl(tt.args.u1, tt.args.u2); got != tt.want {
|
||||||
|
t.Errorf("SpliceUrl() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasUrl(t *testing.T) {
|
||||||
|
strings := []string{"1", "2", "3"}
|
||||||
|
hasUrl := HasUrl(strings, "2")
|
||||||
|
if !hasUrl {
|
||||||
|
t.Errorf("HasUrl() = %v, want %v", false, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMatchUrl(t *testing.T) {
|
||||||
|
all := "*"
|
||||||
|
if !MatchUrl(all, "123") {
|
||||||
|
t.Errorf("HasUrl() = %v, want %v", false, true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsValidUrl(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
u1 string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "error url1",
|
||||||
|
args: args{u1: "htp:23/asd"},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "validated",
|
||||||
|
args: args{u1: "http://123.com:90//ac"},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "error url2",
|
||||||
|
args: args{u1: "http:/123.com:90//ac"},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := IsValidUrl(tt.args.u1); got != tt.want {
|
||||||
|
t.Errorf("IsValidUrl() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddQueryMap(t *testing.T) {
|
||||||
|
m := make(map[string]string)
|
||||||
|
m["1"] = "2"
|
||||||
|
m["2"] = "3"
|
||||||
|
|
||||||
|
if got := AddQueryMap("/sso/auth", m); got != "/sso/auth?1=2&2=3" && got != "/sso/auth?2=3&1=2" {
|
||||||
|
t.Errorf("AddQueryMap() = %v, want %v", got, "/sso/auth?1=2&2=3")
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := AddQueryMap("/sso/auth?", m); got != "/sso/auth?1=2&2=3" && got != "/sso/auth?2=3&1=2" {
|
||||||
|
t.Errorf("AddQueryMap() = %v, want %v", got, "/sso/auth?1=2&2=3")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddQuery(t *testing.T) {
|
||||||
|
if got := AddQuery("/sso", "1", "2"); got != "/sso?1=2" {
|
||||||
|
t.Errorf("AddQuery() = %v, want %v", got, "/sso?1=2")
|
||||||
|
}
|
||||||
|
if got := AddQuery("/sso?", "1", "2"); got != "/sso?1=2" {
|
||||||
|
t.Errorf("AddQuery() = %v, want %v", got, "/sso?1=2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapToQuery(t *testing.T) {
|
||||||
|
m := make(map[string]string)
|
||||||
|
m["1"] = "2"
|
||||||
|
query := MapToQuery(m)
|
||||||
|
if query != "1=2" {
|
||||||
|
t.Errorf("MapToQuery() = %v, want %v", query, "1=2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncode(t *testing.T) {
|
||||||
|
encode := Encode("abc123==123")
|
||||||
|
log.Print("Encode(\"abc123==123\") = " + encode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddQueryValue(t *testing.T) {
|
||||||
|
if got := AddQueryValue("/sso/auth?back=http://123.com/login", "ticket=23324"); got != "/sso/auth?back=http://123.com/login&ticket=23324" {
|
||||||
|
t.Errorf("AddQueryValue() = %v, want %v", got, "/sso/auth?back=http://123.com/login&ticket=23324")
|
||||||
|
}
|
||||||
|
if got := AddQueryValue("/sso/auth?back=http://123.com/login?", "ticket=23324"); got != "/sso/auth?back=http://123.com/login?ticket=23324" {
|
||||||
|
t.Errorf("AddQueryValue() = %v, want %v", got, "/sso/auth?back=http://123.com/login?ticket=23324")
|
||||||
|
}
|
||||||
|
if got := AddQueryValue("/sso/auth?back=http://123.com/login?ticket=123", "redirect=23324"); got != "/sso/auth?back=http://123.com/login?ticket=123&redirect=23324" {
|
||||||
|
t.Errorf("AddQueryValue() = %v, want %v", got, "/sso/auth?back=http://123.com/login?ticket=123&redirect=23324")
|
||||||
|
}
|
||||||
|
if got := AddQueryValue("/sso/auth?back=http://123.com/login?ticket=123&", "redirect=23324"); got != "/sso/auth?back=http://123.com/login?ticket=123&redirect=23324" {
|
||||||
|
t.Errorf("AddQueryValue() = %v, want %v", got, "/sso/auth?back=http://123.com/login?ticket=123&redirect=23324")
|
||||||
|
}
|
||||||
|
}
|
@@ -26,3 +26,12 @@ func InterfaceToBytes(data interface{}) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unable to convert %T to []byte", data)
|
return nil, fmt.Errorf("unable to convert %T to []byte", data)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AppendStr do not add repeated str.
|
||||||
|
// If old slice has newStr, return directly, else append
|
||||||
|
func AppendStr(old []string, newStr string) []string {
|
||||||
|
if HasStr(old, newStr) {
|
||||||
|
return old
|
||||||
|
}
|
||||||
|
return append(old, newStr)
|
||||||
|
}
|
||||||
|
22
util/util_test.go
Normal file
22
util/util_test.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAppendStr(t *testing.T) {
|
||||||
|
strings := []string{"1", "2"}
|
||||||
|
str := AppendStr(strings, "1")
|
||||||
|
str = AppendStr(str, "3")
|
||||||
|
for i, s := range strings {
|
||||||
|
if s == "1" && i == 2 {
|
||||||
|
t.Errorf("AppendStr() = %v, want %v", str, []string{"1", "2"})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
for i, s := range strings {
|
||||||
|
if i == 2 && s != "3" {
|
||||||
|
t.Errorf("AppendStr() = %v, want %v", str, []string{"1", "2", "3"})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user