feat: support SSO

This commit is contained in:
weloe
2023-06-09 03:30:51 +08:00
parent ecef85c8af
commit 446789e261
20 changed files with 1824 additions and 1 deletions

View File

@@ -1,6 +1,6 @@
# Token-Go # Token-Go
This library focuses on solving login authentication problems, such as: login, multi-account login, shared token, logout, kickout ... This library focuses on solving login authentication problems, such as: login, multi-account login, shared token, logout, kickout, banned, SSO ...
## Installation ## Installation
@@ -189,6 +189,9 @@ func CheckAuth(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, "you have authorization") fmt.Fprintf(w, "you have authorization")
} }
``` ```
## SSO
## Api ## Api

30
config/sign.go Normal file
View File

@@ -0,0 +1,30 @@
package config
// SignConfig sign config
type SignConfig struct {
SecretKey string
TimeStampDisparity int64
IsCheckNonce bool
}
func NewSignConfig(options *SignOptions) (*SignConfig, error) {
if options == nil {
options = &SignOptions{}
}
if options.TimeStampDisparity == 0 {
options.TimeStampDisparity = 1000 * 60 * 15
}
return &SignConfig{
SecretKey: options.SecretKey,
TimeStampDisparity: options.TimeStampDisparity,
IsCheckNonce: options.IsCheckNonce,
}, nil
}
func (s *SignConfig) GetSaveNonceExpire() int64 {
if s.TimeStampDisparity >= 0 {
return s.TimeStampDisparity / 1000
} else {
return 60 * 60 * 24
}
}

168
config/sso.go Normal file
View File

@@ -0,0 +1,168 @@
package config
import (
"errors"
"github.com/weloe/token-go/ctx"
model2 "github.com/weloe/token-go/model"
"github.com/weloe/token-go/util"
"strings"
)
func DefaultSsoConfig(serverUrl string, notLoginView func() interface{},
doLoginHandle func(name string, pwd string, ctx ctx.Context) (interface{}, error),
ticketResultHandle func(o1 string, s string) (interface{}, error),
sendHttp func(url string) (string, error)) *SsoConfig {
if notLoginView == nil {
notLoginView = func() interface{} {
return "not logged in to the SSO-Server"
}
}
return &SsoConfig{
Mode: "",
TicketTimeout: 60 * 5,
AllowUrl: "*",
IsSlo: true,
IsHttp: false,
Client: "",
AuthUrl: "/sso/auth",
CheckTicketUrl: "/sso/checkTicket",
GetDataUrl: "/sso/getData",
UserInfoUrl: "/sso/userInfo",
SloUrl: "/sso/signOut",
SsoLogoutCall: "",
ServerUrl: serverUrl,
NotLoginView: notLoginView,
DoLoginHandle: doLoginHandle,
TicketResultHandle: ticketResultHandle,
SendHttp: sendHttp,
}
}
func NewSsoConfig(options *SsoOptions) (*SsoConfig, error) {
if options == nil {
options = &SsoOptions{}
}
if options.TicketTimeout == 0 {
options.TicketTimeout = 60 * 5
}
if options.AllowUrl == "" {
options.AllowUrl = "*"
}
if options.AuthUrl == "" {
options.AuthUrl = "/sso/auth"
}
if options.CheckTicketUrl == "" {
options.CheckTicketUrl = "/sso/checkTicket"
}
if options.GetDataUrl == "" {
options.GetDataUrl = "/sso/getData"
}
if options.UserInfoUrl == "" {
options.UserInfoUrl = "/sso/userInfo"
}
if options.SloUrl == "" {
options.SloUrl = "/sso/signout"
}
if options.NotLoginView == nil {
options.NotLoginView = func() interface{} {
return "not logged in to the SSO-Server"
}
}
if options.DoLoginHandle == nil {
options.DoLoginHandle = func(name string, pwd string, ctx ctx.Context) (interface{}, error) {
return model2.Error(), errors.New("SsoConfig.DoLoginHandle is nil")
}
}
if options.IsHttp && options.SendHttp == nil {
return nil, errors.New("please config SSO SentHttp")
}
return &SsoConfig{
Mode: options.Mode,
TicketTimeout: options.TicketTimeout,
AllowUrl: options.AllowUrl,
IsSlo: options.IsSlo,
IsHttp: options.IsHttp,
Client: options.Client,
AuthUrl: options.AuthUrl,
CheckTicketUrl: options.CheckTicketUrl,
GetDataUrl: options.GetDataUrl,
UserInfoUrl: options.UserInfoUrl,
SloUrl: options.SloUrl,
SsoLogoutCall: options.SsoLogoutCall,
ServerUrl: options.ServerUrl,
NotLoginView: options.NotLoginView,
DoLoginHandle: options.DoLoginHandle,
TicketResultHandle: options.TicketResultHandle,
SendHttp: options.SendHttp,
}, nil
}
type SsoConfig struct {
// Mode sso mode
Mode string
// TicketTimeout ticket timeout
TicketTimeout int64
// AllowUrl All allowed authorization callback addresses, separated by ','
AllowUrl string
IsSlo bool
IsHttp bool
// SSO-Client current client name
Client string
// SSO-Server auth url
AuthUrl string
// SSO-Server check ticket url
CheckTicketUrl string
GetDataUrl string
UserInfoUrl string
SloUrl string
SsoLogoutCall string
ServerUrl string
/**
sso callback func
*/
// NotLoginView
NotLoginView func() interface{}
// DoLoginHandle login func
DoLoginHandle func(name string, pwd string, ctx ctx.Context) (interface{}, error)
// TicketResultHandle called each time the result of the validation ticket is obtained from the SSO-Server
TicketResultHandle func(loginId string, back string) (interface{}, error)
// SendHttp sent http
SendHttp func(url string) (string, error)
}
// SpliceAuthUrl return Server-side single sign-on authorization address
func (c *SsoConfig) SpliceAuthUrl() string {
return util.SpliceUrl(c.ServerUrl, c.AuthUrl)
}
// SpliceCheckTicketUrl return ticket verification address on the server side
func (c *SsoConfig) SpliceCheckTicketUrl() string {
return util.SpliceUrl(c.ServerUrl, c.CheckTicketUrl)
}
func (c *SsoConfig) SpliceGetDataUrl() string {
return util.SpliceUrl(c.ServerUrl, c.GetDataUrl)
}
func (c *SsoConfig) SpliceUserInfoUrl() string {
return util.SpliceUrl(c.ServerUrl, c.UserInfoUrl)
}
func (c *SsoConfig) SpliceSloUrl() string {
return util.SpliceUrl(c.ServerUrl, c.SloUrl)
}
// SetAllow set allow callback url
func (c *SsoConfig) SetAllow(urls ...string) *SsoConfig {
c.AllowUrl = strings.Join(urls, ",")
return c
}

52
config/sso_options.go Normal file
View File

@@ -0,0 +1,52 @@
package config
import "github.com/weloe/token-go/ctx"
// SsoOptions new SsoConfig options.
type SsoOptions struct {
CookieDomain string
// Mode sso mode
Mode string
// TicketTimeout ticket timeout
TicketTimeout int64
// AllowUrl All allowed authorization callback addresses, separated by ','
AllowUrl string
IsSlo bool
IsHttp bool
// SSO-Client current client name
Client string
// SSO-Server auth url
AuthUrl string
// SSO-Server check ticket url
CheckTicketUrl string
GetDataUrl string
UserInfoUrl string
SloUrl string
SsoLogoutCall string
ServerUrl string
/**
sso callback func
*/
// NotLoginView
NotLoginView func() interface{}
// DoLoginHandle login func
DoLoginHandle func(name string, pwd string, ctx ctx.Context) (interface{}, error)
// TicketResultHandle called each time the result of the validation ticket is obtained from the SSO-Server
TicketResultHandle func(loginId string, back string) (interface{}, error)
// SendHttp sent http
SendHttp func(url string) (string, error)
}
// SignOptions SignConfig options
type SignOptions struct {
SecretKey string
TimeStampDisparity int64
IsCheckNonce bool
}

11
constant/sso.go Normal file
View File

@@ -0,0 +1,11 @@
package constant
const (
SLO_CALLBACK_SET_KEY = "SLO_CALLBACK_SET_KEY_"
SELF = "self"
MODE_SIMPLE = "simple"
NOT_HANDLE = "{\"msg\": \"not handle\"}"
)

39
model/http_result.go Normal file
View File

@@ -0,0 +1,39 @@
package model
const (
SUCCESS = 1
ERROR = 0
)
// Result wrap the http request result.
type Result struct {
Code int
Msg string
Data interface{}
}
func Ok() *Result {
return &Result{
Code: SUCCESS,
Msg: "success",
Data: nil,
}
}
func Error() *Result {
return &Result{
Code: -1,
Msg: "error",
Data: nil,
}
}
func (r *Result) SetData(data interface{}) *Result {
r.Data = data
return r
}
func (r *Result) SetMsg(msg string) *Result {
r.Msg = msg
return r
}

34
sso/api.go Normal file
View File

@@ -0,0 +1,34 @@
package sso
// ApiName sso api name, used to dispatcher request.
type ApiName struct {
// sso-server auth url
SsoAuth string
// sso-server rest api login url
SsoDoLogin string
// sso-server check ticket url
SsoCheckTicket string
// sso-server get user info url
SsoUserInfo string
// sso-server single logout url
SsoSignout string
// sso-client login url
SsoLogin string
// sso-client single logout url
SsoLogout string
// sso-client logout callback url
SsoLogoutCall string
}
func DefaultApiName() *ApiName {
return &ApiName{
SsoAuth: "/sso/auth",
SsoDoLogin: "/sso/doLogin",
SsoCheckTicket: "/sso/checkTicket",
SsoUserInfo: "/sso/userInfo",
SsoSignout: "/sso/signout",
SsoLogin: "/sso/login",
SsoLogout: "/sso/logout",
SsoLogoutCall: "/sso/logoutCall",
}
}

39
sso/param.go Normal file
View File

@@ -0,0 +1,39 @@
package sso
// ParamName http request param name.
type ParamName struct {
Redirect string
Ticket string
Back string
Mode string
LoginId string
Client string
SsoLogoutCall string
Name string
Pwd string
//=== sign param
TimeStamp string
Nonce string
Sign string
SecretKet string
}
func DefaultParamName() *ParamName {
return &ParamName{
Redirect: "redirect",
Ticket: "ticket",
Back: "back",
Mode: "mode",
LoginId: "loginId",
Client: "client",
SsoLogoutCall: "ssoLogoutCall",
Name: "name",
Pwd: "pwd",
TimeStamp: "timestamp",
Nonce: "nonce",
Sign: "sign",
SecretKet: "key",
}
}

108
sso/sso.go Normal file
View File

@@ -0,0 +1,108 @@
package sso
import (
"errors"
tokenGo "github.com/weloe/token-go"
"github.com/weloe/token-go/config"
"github.com/weloe/token-go/constant"
"github.com/weloe/token-go/ctx"
"github.com/weloe/token-go/model"
)
// Options construct options
type Options struct {
SsoOptions *config.SsoOptions
SignOptions *config.SignOptions
Enforcer tokenGo.IEnforcer
}
type SsoEnforcer struct {
apiName *ApiName
paramName *ParamName
config *config.SsoConfig
signConfig *config.SignConfig
enforcer tokenGo.IEnforcer
}
// NewSsoEnforcer create sso enforcer.
// If the available field in the parameter is empty, use the default value,
// if the required field is empty, return nil and error.
func NewSsoEnforcer(options *Options) (*SsoEnforcer, error) {
if options.Enforcer == nil {
return nil, errors.New("Options.Enforcer can not be nil")
}
if options.SsoOptions.CookieDomain != "" {
options.Enforcer.GetTokenConfig().CookieConfig.Domain = options.SsoOptions.CookieDomain
}
ssoConfig, err := config.NewSsoConfig(options.SsoOptions)
if err != nil {
return nil, err
}
signConfig, err := config.NewSignConfig(options.SignOptions)
if err != nil {
return nil, err
}
return &SsoEnforcer{
apiName: DefaultApiName(),
paramName: DefaultParamName(),
config: ssoConfig,
signConfig: signConfig,
enforcer: options.Enforcer,
}, nil
}
func (s *SsoEnforcer) SetApi(apiName *ApiName) {
s.apiName = apiName
}
func (s *SsoEnforcer) GetApi() *ApiName {
return s.apiName
}
func (s *SsoEnforcer) SetParamName(paramName *ParamName) {
s.paramName = paramName
}
func (s *SsoEnforcer) GetParamName() *ParamName {
return s.paramName
}
func (s *SsoEnforcer) SetSsoConfig(config *config.SsoConfig) {
s.config = config
}
func (s *SsoEnforcer) GetSsoConfig() *config.SsoConfig {
return s.config
}
func (s *SsoEnforcer) SetSignConfig(config *config.SignConfig) {
s.signConfig = config
}
func (s *SsoEnforcer) GetSignConfig() *config.SignConfig {
return s.signConfig
}
// ssoLogoutBack single-logout callback, redirect to ParamName.Back url.
// If http request has back param and value is SELF, redirect to previous page,
// if http request back param and value is validated url, redirect to url,
// else return model.Ok() directly.
func (s *SsoEnforcer) ssoLogoutBack(ctx ctx.Context) (interface{}, error) {
paramName := s.paramName
request := ctx.Request()
response := ctx.Response()
back := request.Query(paramName.Back)
if back != "" {
if back == constant.SELF {
return "<script>if(document.referrer != location.href){ location.replace(document.referrer || '/'); }</script>", nil
}
response.Redirect(back)
return nil, nil
} else {
// back is nil
return model.Ok().SetMsg("back is nil, not redirect"), nil
}
}

206
sso/sso_client_api.go Normal file
View File

@@ -0,0 +1,206 @@
package sso
import (
"errors"
"github.com/weloe/token-go/ctx"
"github.com/weloe/token-go/model"
"strings"
)
/**
=========processor SSO-Client api
*/
// SsoClientLogin SSO-Client login.
func (s *SsoEnforcer) SsoClientLogin(ctx ctx.Context) (interface{}, error) {
request := ctx.Request()
response := ctx.Response()
paramName := s.paramName
apiName := s.apiName
// get back value and redirect
back := request.Query(paramName.Back)
if back == "" {
back = "/"
}
isLogin, err := s.enforcer.IsLogin(ctx)
if err != nil {
return nil, err
}
// if the current client is already logged in, there is no need to redirect to the SSO-Server
if isLogin {
response.Redirect(back)
return nil, nil
}
// if isLogin == false, attempt to get ticket
ticket := request.Query(paramName.Ticket)
// if ticket == "", redirect to SSO-Server, the default path is /sso/auth
if ticket == "" {
serverAuthUrl, err := s.buildServerAuthUrl(request.UrlNoQuery(), back)
if err != nil {
return nil, err
}
response.Redirect(serverAuthUrl)
return nil, nil
}
// else ticket != "", need to log in by ticket
var loginId string
// get current path
ssoLoginUrl := apiName.SsoLogin
if s.config.IsHttp {
// use http request to SSO-Server to check ticket in SSO-Server
var ssoLogoutCall string
if s.config.IsSlo {
// get logout callback url
if s.config.SsoLogoutCall != "" {
ssoLogoutCall = s.config.SsoLogoutCall
} else if ssoLoginUrl != "" {
ssoLogoutCall = strings.ReplaceAll(request.UrlNoQuery(), ssoLoginUrl, apiName.SsoLogoutCall)
}
}
// send http to check
checkTicketUrl, err := s.buildCheckTicketUrl(ticket, ssoLogoutCall)
if err != nil {
return nil, err
}
// send http request
resp, err := s.request(checkTicketUrl)
if err != nil {
return nil, err
}
if resp.Code == model.ERROR {
return nil, errors.New("request failed: " + resp.Msg)
}
loginId = resp.Data.(string)
} else {
// use adapter to check ticket
loginId, err = s.CheckTicket(ticket)
if err != nil {
return nil, err
}
}
// if set callback TicketResultHandle
if s.config.TicketResultHandle != nil {
return s.config.TicketResultHandle(loginId, back)
}
// if loginId == "", return error
if loginId == "" {
return nil, errors.New("invalid ticket: " + ticket)
}
// login
_, err = s.enforcer.Login(loginId, ctx)
if err != nil {
return nil, err
}
// redirect to back
response.Redirect(back)
return nil, nil
}
// SsoClientLogout SSO-Client single-logout.
func (s *SsoEnforcer) SsoClientLogout(ctx ctx.Context) (interface{}, error) {
ssoConfig := s.config
// enable single-logout and isHttp == false
if ssoConfig.IsSlo && !ssoConfig.IsHttp {
isLogin, err := s.enforcer.IsLogin(ctx)
if err != nil {
return nil, err
}
// check if you are logged in
if isLogin {
var id string
id, err = s.enforcer.GetLoginId(ctx)
if err != nil {
return nil, err
}
// client logout
err = s.enforcer.LogoutById(id)
if err != nil {
return nil, err
}
// callback
return s.ssoLogoutBack(ctx)
}
}
// enable single-logout and isHttp
if ssoConfig.IsSlo && ssoConfig.IsHttp {
isLogin, err := s.enforcer.IsLogin(ctx)
if err != nil {
return nil, err
}
if !isLogin {
return s.ssoLogoutBack(ctx)
}
// get id
var id string
id, err = s.enforcer.GetLoginId(ctx)
if err != nil {
return nil, err
}
// use SSO-Server single-logout api to logout
sloUrl, err := s.buildSloUrl(id)
if err != nil {
return nil, err
}
res, err := s.request(sloUrl)
if err != nil {
return nil, err
}
// check response data
if res.Code == model.SUCCESS {
login, _ := s.enforcer.IsLogin(ctx)
if login {
err := s.enforcer.Logout(ctx)
if err != nil {
return nil, err
}
}
return s.ssoLogoutBack(ctx)
} else {
return nil, errors.New("request failed: " + res.Msg)
}
}
return nil, errors.New("not handle")
}
// SsoClientLogoutCall client logout callback.
func (s *SsoEnforcer) SsoClientLogoutCall(ctx ctx.Context) (interface{}, error) {
request := ctx.Request()
loginId := request.Query(s.paramName.LoginId)
if loginId == "" {
return nil, errors.New("request param must include loginId")
}
// check request param
err := s.checkRequest(request)
if err != nil {
return nil, err
}
// logout
err = s.enforcer.Logout(ctx)
if err != nil {
return nil, err
}
return model.Ok().SetMsg("logout callback success"), nil
}
// GetData client build url to sent http to get data from SSO-Server.
func (s *SsoEnforcer) GetData(paramMap map[string]string) (interface{}, error) {
finalUrl, err := s.buildGetDataUrl(paramMap)
if err != nil {
return nil, err
}
res, err := s.config.SendHttp(finalUrl)
return res, err
}

64
sso/sso_dispatcher_api.go Normal file
View File

@@ -0,0 +1,64 @@
package sso
import (
"github.com/weloe/token-go/ctx"
"github.com/weloe/token-go/model"
)
/**
=========dispatcher api
*/
// ServerDisPatcher dispatcher SSO-Server api, returns model.Result or string.
func (s *SsoEnforcer) ServerDisPatcher(ctx ctx.Context) interface{} {
request := ctx.Request()
apiName := s.apiName
path := request.Path()
var res interface{}
var err error
if path == apiName.SsoAuth {
res, err = s.SsoAuth(ctx)
} else if path == apiName.SsoDoLogin {
res, err = s.SsoDoLogin(ctx)
} else if path == apiName.SsoCheckTicket && s.config.IsHttp {
res, err = s.SsoCheckTicket(ctx)
} else if path == apiName.SsoSignout {
res, err = s.SsoSignOut(ctx)
} else {
return model.Error().SetMsg("not handle")
}
if err != nil {
return model.Error().SetMsg(err.Error())
}
if res == nil {
return model.Ok()
}
return res
}
// ClientDispatcher dispatcher Client api, returns model.Result or string.
func (s *SsoEnforcer) ClientDispatcher(ctx ctx.Context) interface{} {
request := ctx.Request()
apiName := s.apiName
path := request.Path()
var res interface{}
var err error
if path == apiName.SsoLogin {
res, err = s.SsoClientLogin(ctx)
} else if path == apiName.SsoLogout {
res, err = s.SsoClientLogout(ctx)
} else if path == apiName.SsoLogoutCall && s.config.IsSlo && s.config.IsHttp {
res, err = s.SsoClientLogoutCall(ctx)
} else {
return model.Error().SetMsg("not handle")
}
if err != nil {
return model.Error().SetMsg(err.Error())
}
if res == nil {
return model.Ok()
}
return res
}

494
sso/sso_internal_api.go Normal file
View File

@@ -0,0 +1,494 @@
package sso
import (
"encoding/json"
"errors"
"fmt"
"github.com/weloe/token-go/constant"
"github.com/weloe/token-go/ctx"
"github.com/weloe/token-go/model"
"github.com/weloe/token-go/util"
"log"
"strconv"
"strings"
"time"
)
/**
=========internal api
*/
// checkRequest check request param timestamp,nonce,sign.
func (s *SsoEnforcer) checkRequest(request ctx.Request) error {
timestamp := request.Query(s.paramName.TimeStamp)
nonce := request.Query(s.paramName.Nonce)
sign := request.Query(s.paramName.Sign)
err := s.checkTimeStamp(timestamp)
if err != nil {
return err
}
if s.signConfig.IsCheckNonce {
err = s.checkNonce(nonce)
if err != nil {
return err
}
}
err = s.checkSign(timestamp, nonce, sign)
if err != nil {
return err
}
return nil
}
// CreateTicket create ticket by account-id.
func (s *SsoEnforcer) CreateTicket(loginId string, client string) (string, error) {
// create random string ticket
ticket, err := util.GenerateRandomString64()
if err != nil {
return "", err
}
// save ticket-id+client
err = s.saveTicket(ticket, loginId, client)
if err != nil {
return "", err
}
// save id-ticket
err = s.saveTicketIndex(ticket, loginId)
if err != nil {
return "", err
}
return ticket, nil
}
// GetLoginId get loginId by ticket.
func (s *SsoEnforcer) GetLoginId(ticket string) string {
if ticket == "" {
return ""
}
loginId := s.enforcer.GetAdapter().GetStr(s.spliceTicketSaveKey(ticket))
if loginId != "" && strings.Contains(loginId, ",") {
split := strings.Split(loginId, ",")
loginId = split[0]
}
return loginId
}
// GetTicket get ticket by loginId.
func (s *SsoEnforcer) GetTicket(loginId string) string {
if loginId == "" {
return ""
}
return s.enforcer.GetAdapter().GetStr(s.spliceTicketIndexKey(loginId))
}
// CheckTicket use config.Client to check ticket,return loginId.
func (s *SsoEnforcer) CheckTicket(ticket string) (string, error) {
return s.CheckTicketByClient(ticket, s.config.Client)
}
// CheckTicketByClient check ticket by pointing client,return loginId.
func (s *SsoEnforcer) CheckTicketByClient(ticket string, client string) (string, error) {
id := s.enforcer.GetAdapter().GetStr(s.spliceTicketSaveKey(ticket))
if id == "" {
return "", nil
}
var ticketClient string
if strings.Contains(id, ",") {
split := strings.Split(id, ",")
id = split[0]
ticketClient = split[1]
}
if client != "" && client != ticketClient {
return "", fmt.Errorf("the ticket does not belong to the client, client: %v, ticket: %v", client, ticket)
}
err := s.deleteTicket(ticket)
if err != nil {
return "", err
}
err = s.deleteTicketIndex(id)
if err != nil {
return "", err
}
return id, nil
}
// CheckRedirectUrl check redirectUrl.
func (s *SsoEnforcer) CheckRedirectUrl(url string) error {
if !util.IsValidUrl(url) {
return fmt.Errorf("invalid redirect url: %v", url)
}
index := strings.Index(url, "?")
if index != -1 {
url = url[0:index]
}
allowUrls := strings.Split(s.GetAllowUrl(), ",")
if !util.HasUrl(allowUrls, url) {
return fmt.Errorf("illegal redirect url: %v", url)
}
return nil
}
// RegisterSloCallbackUrl register the URL of the single logout callback for the account id.
func (s *SsoEnforcer) RegisterSloCallbackUrl(loginId string, sloCallbackUrl string) error {
if loginId == "" || sloCallbackUrl == "" {
return nil
}
session := s.enforcer.GetSession(loginId)
// splice session id
sessionId := s.enforcer.GetTokenConfig().TokenName + ":" + s.enforcer.GetType() + ":session:" + loginId
if session == nil {
session = model.NewSession(sessionId, "account-session", loginId)
}
value := session.Get(constant.SLO_CALLBACK_SET_KEY)
var v []string
if value != nil {
sv, ok := value.([]string)
v = sv
if !ok {
return errors.New("session SLO_CALLBACK_SET_KEY_ data convert into []string failed")
}
}
v = util.AppendStr(v, sloCallbackUrl)
session.Set(sessionId, v)
// update session
err := s.enforcer.UpdateSession(loginId, session)
if err != nil {
return err
}
return nil
}
// ssoSignOutById single sign-out of the specified account.
// Use loginId to get single sign-out urls from session.
func (s *SsoEnforcer) ssoSignOutById(loginId string) error {
// if loginId is not logged, return error
session := s.enforcer.GetSession(loginId)
if session == nil {
return errors.New("this loginId is not logged in")
}
value := session.Get(constant.SLO_CALLBACK_SET_KEY)
if value == nil {
return nil
}
urls, ok := value.([]string)
if !ok {
return errors.New("convert into []string failed")
}
// range urls to make client logout
for _, url := range urls {
// join url
newUrl, err := s.joinLoginIdAndSign(url, loginId)
if err != nil {
return err
}
// sent http
_, err = s.config.SendHttp(newUrl)
if err != nil {
return err
}
}
// server logout
return s.enforcer.LogoutById(loginId)
}
// joinLoginIdAndSign splice the loginId to the url, and stitching parameters such as sign.
func (s *SsoEnforcer) joinLoginIdAndSign(url string, id string) (string, error) {
nonce, err := util.GenerateRandomString32()
if err != nil {
return "", err
}
timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
sign, err := s.createSign(timestamp, nonce)
if err != nil {
return "", err
}
str := url + "?" + s.paramName.LoginId + "=" + id + "&" + s.paramName.TimeStamp + "=" + timestamp + "&" + s.paramName.Nonce + "=" + nonce + "&" + s.paramName.Sign + "=" + sign
return str, nil
}
// buildServerAuthUrl SSO-Client build SSO-Server single sign-on url.
func (s *SsoEnforcer) buildServerAuthUrl(clientLoginUrl string, back string) (string, error) {
if clientLoginUrl == "" {
return "", errors.New("arg[0] clientLoginUrl can not be nil")
}
// get server auth url
authUrl := s.config.SpliceAuthUrl()
client := s.config.Client
if client != "" {
authUrl = util.AddQuery(authUrl, s.paramName.Client, client)
}
// splice back url
if back != "" {
back = util.Encode(back)
clientLoginUrl = util.AddQuery(clientLoginUrl, s.paramName.Back, back)
}
return util.AddQuery(authUrl, s.paramName.Redirect, clientLoginUrl), nil
}
// buildRedirectUrl the server gives the redirectUrl of the ticket to the client.
// Check redirect url, delete old ticket of loginId and create new ticket, then return url with new ticket.
func (s *SsoEnforcer) buildRedirectUrl(loginId string, client string, redirect string) (string, error) {
// check redirect url
err := s.CheckRedirectUrl(redirect)
if err != nil {
return "", err
}
// delete old ticket
err = s.deleteTicket(s.GetTicket(loginId))
if err != nil {
return "", err
}
// create new ticket
ticket, err := s.CreateTicket(loginId, client)
if err != nil {
return "", err
}
// return redirect + "?" + s.paramName.Ticket + "=" + ticket, nil
return util.AddQuery(s.encodeBackParam(redirect), s.paramName.Ticket, ticket), nil
}
// encodeBackParam find back param from url, and encode back param.
func (s *SsoEnforcer) encodeBackParam(url string) string {
// get back location
index := strings.Index(url, "?"+s.paramName.Back+"=")
if index == -1 {
index = strings.Index(url, "&"+s.paramName.Back+"=")
if index == -1 {
return url
}
}
// encode
length := len(s.paramName.Back) + 2
back := url[index+length:]
back = util.Encode(back)
// update back
url = url[:index+length] + back
return url
}
// buildCheckTicketUrl build to check ticket.
func (s *SsoEnforcer) buildCheckTicketUrl(ticket string, ssoLogoutCallUrl string) (string, error) {
if ticket == "" {
return "", errors.New("buildCheckTicketUrl() ticket can not be nil")
}
checkTicketUrl := s.config.SpliceCheckTicketUrl()
client := s.config.Client
paramMap := make(map[string]string)
if client != "" {
paramMap[s.paramName.Client] = client
}
paramMap[s.paramName.Ticket] = ticket
if ssoLogoutCallUrl != "" {
paramMap[s.paramName.SsoLogoutCall] = ssoLogoutCallUrl
}
return util.AddQueryMap(checkTicketUrl, paramMap), nil
}
// buildSloUrl build single-logout url.
func (s *SsoEnforcer) buildSloUrl(loginId string) (string, error) {
sloUrl := s.config.SpliceSloUrl()
url, err := s.joinLoginIdAndSign(sloUrl, loginId)
if err != nil {
return "", err
}
return url, nil
}
// buildGetDataUrl build getData url with sign,timestamp,nonce.
func (s *SsoEnforcer) buildGetDataUrl(paramMap map[string]string) (string, error) {
getDataUrl := s.config.SpliceGetDataUrl()
return s.buildCustomPathUrl(getDataUrl, paramMap)
}
// buildCustomPathUrl add paramMap to path.
func (s *SsoEnforcer) buildCustomPathUrl(path string, paramMap map[string]string) (string, error) {
u := path
if !strings.HasPrefix(u, "http") {
serverUrl := s.config.ServerUrl
if serverUrl == "" {
return "", errors.New("please set sso serverUrl")
}
u = util.SpliceUrl(serverUrl, path)
}
// sign map
timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
paramMap[s.paramName.TimeStamp] = timestamp
nonce, err := util.GenerateRandomString32()
if err != nil {
return "", err
}
paramMap[s.paramName.Nonce] = nonce
// create sign
sign, err := s.createSign(timestamp, nonce)
if err != nil {
return "", err
}
paramMap[s.paramName.Sign] = sign
finalUrl := util.AddQueryMap(u, paramMap)
return finalUrl, nil
}
// request send http request and use json.Unmarshal to converted to *model.Result.
func (s *SsoEnforcer) request(url string) (*model.Result, error) {
resp, err := s.config.SendHttp(url)
log.Printf("http request response: %s", resp)
if err != nil {
return nil, err
}
result := &model.Result{}
err = json.Unmarshal([]byte(resp), result)
if err != nil {
return nil, err
}
return result, nil
}
func (s *SsoEnforcer) GetAllowUrl() string {
return s.config.AllowUrl
}
// saveTicket save ticket-id+client.
func (s *SsoEnforcer) saveTicket(ticket string, loginId string, client string) error {
value := loginId
if client != "" {
value += "," + client
}
ticketTimeout := s.config.TicketTimeout
return s.enforcer.GetAdapter().SetStr(s.spliceTicketSaveKey(ticket), value, ticketTimeout)
}
// saveTicketIndex save id-ticket.
func (s *SsoEnforcer) saveTicketIndex(ticket string, id string) error {
ticketTimeout := s.config.TicketTimeout
return s.enforcer.GetAdapter().SetStr(s.spliceTicketIndexKey(id), ticket, ticketTimeout)
}
// spliceTicketSaveKey splice ticket-id key.
func (s *SsoEnforcer) spliceTicketSaveKey(ticket string) string {
return s.enforcer.GetTokenConfig().TokenName + ":ticket:" + ticket
}
// spliceTicketIndexKey splice id-ticket key.
func (s *SsoEnforcer) spliceTicketIndexKey(id string) string {
return s.enforcer.GetTokenConfig().TokenName + ":id-ticket:" + id
}
func (s *SsoEnforcer) deleteTicket(ticket string) error {
if ticket == "" {
return nil
}
return s.enforcer.GetAdapter().DeleteStr(s.spliceTicketSaveKey(ticket))
}
func (s *SsoEnforcer) deleteTicketIndex(id string) error {
if id == "" {
return nil
}
return s.enforcer.GetAdapter().DeleteStr(s.spliceTicketIndexKey(id))
}
// checkTimeStamp determine whether the gap between the timestamp and the current timestamp is within the allowable range.
func (s *SsoEnforcer) checkTimeStamp(timestamp string) error {
parseInt, err := strconv.ParseInt(timestamp, 10, 64)
if err != nil {
return err
}
if !s.isValidTimeStamp(parseInt) {
return errors.New("timestamp is out of allowed range: " + timestamp)
}
return nil
}
// isValidTimeStamp determine whether the gap between the timestamp and the current timestamp is within the allowable range.
func (s *SsoEnforcer) isValidTimeStamp(timestamp int64) bool {
allowDisparity := s.signConfig.TimeStampDisparity
nowDisparity := time.Now().UnixMilli() - timestamp
return allowDisparity == 1 || nowDisparity <= allowDisparity
}
// checkNonce the same nonce can only be verified once.
func (s *SsoEnforcer) checkNonce(nonce string) error {
if nonce == "" {
return errors.New("nonce is nil")
}
// if nonce exists in adapter
if !s.isValidNonce(nonce) {
return errors.New("the nonce has been used: " + nonce)
}
// set nonce after nonce
err := s.enforcer.GetAdapter().SetStr(s.spliceNonceSaveKey(nonce), nonce, s.signConfig.GetSaveNonceExpire()*2+2)
if err != nil {
return err
}
return nil
}
// isValidNonce determine random string, if not exist in adapter, return true.
func (s *SsoEnforcer) isValidNonce(nonce string) bool {
if nonce == "" {
return false
}
key := s.spliceNonceSaveKey(nonce)
// if not exist in adapter, return true.
return s.enforcer.GetAdapter().GetStr(key) == ""
}
func (s *SsoEnforcer) checkSign(timestamp string, nonce string, sign string) error {
valid, err := s.isValidSign(timestamp, nonce, sign)
if err != nil {
return err
}
if !valid {
return errors.New("invalid sign: " + sign)
}
return nil
}
// isValidSign use timestamp,nonce,sign to createSign to compare.
func (s *SsoEnforcer) isValidSign(timestamp string, nonce string, sign string) (bool, error) {
recreateSign, err := s.createSign(timestamp, nonce)
if err != nil {
return false, err
}
return recreateSign == sign, nil
}
// createSign use util.MD5() to generate str.
func (s *SsoEnforcer) createSign(timestamp string, nonce string) (string, error) {
secretKey := s.signConfig.SecretKey
if secretKey == "" {
return "", errors.New("please check SignConfig.SecretKey, SecretKey can not be nil")
}
str := s.paramName.Nonce + "=" + nonce + "&" + s.paramName.TimeStamp + "=" + timestamp + "&" + s.paramName.SecretKet + "=" + secretKey
return util.MD5(str), nil
}
// spliceNonceSaveKey splice nonce store key.
func (s *SsoEnforcer) spliceNonceSaveKey(nonce string) string {
return s.enforcer.GetTokenConfig().TokenName + ":sign:nonce:" + nonce
}

135
sso/sso_server_api.go Normal file
View File

@@ -0,0 +1,135 @@
package sso
import (
"errors"
"github.com/weloe/token-go/constant"
"github.com/weloe/token-go/ctx"
"github.com/weloe/token-go/model"
)
/**
=========processor SSO-Server api
*/
// SsoAuth SSO-Server: auth.
func (s *SsoEnforcer) SsoAuth(ctx ctx.Context) (interface{}, error) {
request := ctx.Request()
response := ctx.Response()
isLogin, err := s.enforcer.IsLogin(ctx)
if err != nil {
return nil, err
}
// if you have not logged in to the SSO-Server, need to log in first
if !isLogin {
return s.config.NotLoginView(), nil
}
// if you have logged, check the mode
mode := request.Query(s.paramName.Mode)
redirect := request.Query(s.paramName.Redirect)
// if mode == simple, redirect to client directly
if mode == constant.MODE_SIMPLE {
err = s.CheckRedirectUrl(redirect)
if err != nil {
return nil, err
}
response.Redirect(redirect)
return nil, nil
} else {
// mode = ticket, redirect to client login with new ticket
id, err := s.enforcer.GetLoginId(ctx)
if err != nil {
return nil, err
}
redirectUrl, err := s.buildRedirectUrl(id, request.Query(s.paramName.Client), redirect)
if err != nil {
return nil, err
}
response.Redirect(redirectUrl)
return nil, nil
}
}
// SsoDoLogin SSO-Server: rest login api.
func (s *SsoEnforcer) SsoDoLogin(ctx ctx.Context) (interface{}, error) {
request := ctx.Request()
paramName := s.paramName
if s.config.DoLoginHandle == nil {
return nil, errors.New("SsoConfig.DoLoginHandle is nil")
}
resp, err := s.config.DoLoginHandle(request.Query(paramName.Name), request.Query(paramName.Pwd), ctx)
if err != nil {
return nil, err
}
return resp, nil
}
// SsoCheckTicket SSO-Server: check ticket to get loginId, returns loginId
func (s *SsoEnforcer) SsoCheckTicket(ctx ctx.Context) (interface{}, error) {
paramName := s.paramName
request := ctx.Request()
client := request.Query(paramName.Client)
ticket := request.Query(paramName.Ticket)
if ticket == "" {
return nil, errors.New("ticket can not be nil")
}
sloCallback := request.Query(paramName.SsoLogoutCall)
// check ticket
loginId, err := s.CheckTicketByClient(ticket, client)
if err != nil {
return nil, err
}
// register single sign out callback url
err = s.RegisterSloCallbackUrl(loginId, sloCallback)
if err != nil {
return nil, err
}
if loginId == "" {
return nil, errors.New("invalid ticket: " + ticket)
}
return model.Ok().SetData(loginId), nil
}
// SsoSignOut SSO-Server: single sign-out.
func (s *SsoEnforcer) SsoSignOut(ctx ctx.Context) (interface{}, error) {
request := ctx.Request()
paramName := s.paramName
// if enable single sign-out and request param has loginId
reqLoginId := request.Query(paramName.LoginId)
if s.config.IsSlo && reqLoginId == "" {
loginId, err := s.enforcer.GetLoginId(ctx)
if err != nil {
return nil, err
}
if loginId != "" {
err = s.ssoSignOutById(loginId)
if err != nil {
return nil, err
}
// callback
return s.ssoLogoutBack(ctx)
}
}
// if enable http,single sign-out and request param has loginId
if s.config.IsHttp && s.config.IsSlo && reqLoginId != "" {
err := s.checkRequest(request)
if err != nil {
return nil, err
}
// Use loginId to get single sign-out urls from session, traverse the urls to send the request to notify client
err = s.ssoSignOutById(reqLoginId)
if err != nil {
return nil, err
}
return model.Ok().SetMsg("sso sign-out success"), nil
}
return nil, errors.New("not handle")
}

103
sso/sso_test.go Normal file
View File

@@ -0,0 +1,103 @@
package sso
import (
tokenGo "github.com/weloe/token-go"
"github.com/weloe/token-go/config"
"github.com/weloe/token-go/ctx"
"github.com/weloe/token-go/model"
"github.com/weloe/token-go/util"
"testing"
)
func SendGetRequest(url string) (string, error) {
return util.SendGetRequest(url)
}
func TestNewSsoServerEnforcer(t *testing.T) {
var err error
// use default adapter
adapter := tokenGo.NewDefaultAdapter()
enforcer, err := tokenGo.NewEnforcer(adapter)
if err != nil {
t.Errorf("NewEnforcer() failed: %v", err)
}
// enable logger
enforcer.EnableLog()
ssoOptions := &config.SsoOptions{
Mode: "",
TicketTimeout: 300,
AllowUrl: "*",
IsSlo: true,
IsHttp: true,
ServerUrl: "http://token-go-sso-server.com:9000",
NotLoginView: func() interface{} {
msg := "not log in SSO-Server, please visit <a href='/sso/doLogin?name=tokengo&pwd=123456' target='_blank'> doLogin </a>"
return msg
},
DoLoginHandle: func(name string, pwd string, ctx ctx.Context) (interface{}, error) {
if name != "tokengo" {
return "name error", nil
}
if pwd != "123456" {
return "pwd error", nil
}
token, err := enforcer.Login("1001", ctx)
if err != nil {
return nil, err
}
return model.Ok().SetData(token), nil
},
SendHttp: func(url string) (string, error) {
return SendGetRequest(url)
},
}
signOptions := &config.SignOptions{
SecretKey: "kQwIOrYvnXmSDkwEiFngrKidMcdrgKor",
IsCheckNonce: true,
}
ssoEnforcer, err := NewSsoEnforcer(&Options{
SsoOptions: ssoOptions,
SignOptions: signOptions,
Enforcer: enforcer,
})
if err != nil {
t.Errorf("NewSsoEnforcer() failed: %v", err)
}
t.Logf("enforcer: %v", ssoEnforcer)
}
func TestNewSsoClient3Enforcer(t *testing.T) {
var err error
// use default adapter
adapter := tokenGo.NewDefaultAdapter()
enforcer, err := tokenGo.NewEnforcer(adapter)
if err != nil {
t.Errorf("NewEnforcer() failed: %v", err)
}
// enable logger
enforcer.EnableLog()
ssoOptions := &config.SsoOptions{
AuthUrl: "/sso/auth",
IsSlo: true,
IsHttp: true,
SloUrl: "/sso/signout",
CheckTicketUrl: "/sso/checkTicket",
ServerUrl: "http://token-go-sso-server.com:9000",
SendHttp: func(url string) (string, error) {
return SendGetRequest(url)
},
}
signOptions := &config.SignOptions{
SecretKey: "kQwIOrYvnXmSDkwEiFngrKidMcdrgKor",
IsCheckNonce: true,
}
ssoEnforcer, err := NewSsoEnforcer(&Options{
SsoOptions: ssoOptions,
SignOptions: signOptions,
Enforcer: enforcer,
})
if err != nil {
t.Errorf("NewSsoEnforcer() failed: %v", err)
}
t.Logf("enforcer: %v", ssoEnforcer)
}

13
util/secure.go Normal file
View File

@@ -0,0 +1,13 @@
package util
import (
"crypto/md5"
"fmt"
)
func MD5(str string) string {
data := []byte(str)
has := md5.Sum(data)
md5str := fmt.Sprintf("%x", has)
return md5str
}

12
util/sign.go Normal file
View File

@@ -0,0 +1,12 @@
package util
import (
"time"
)
// IsValidTimeStamp determine whether the gap between the startTime and the current timestamp is within the allowable range.
func IsValidTimeStamp(startTime int64, allowDisparity int64) bool {
nowDisparity := time.Now().UnixMilli() - startTime
return allowDisparity == 1 || nowDisparity <= allowDisparity
}

137
util/url.go Normal file
View File

@@ -0,0 +1,137 @@
package util
import (
"io"
"io/ioutil"
"log"
"net/http"
"net/url"
"strings"
)
func SendGetRequest(url string) (string, error) {
response, err := http.Get(url)
if err != nil {
log.Printf("http.Get() failed: %v", err)
return "", err
}
defer func(Body io.ReadCloser) {
err = Body.Close()
if err != nil {
log.Printf("read response body failed: %v", err)
}
}(response.Body)
body, err := ioutil.ReadAll(response.Body)
if err != nil {
log.Printf("ioutil.ReadAll() failed: %v", err)
return "", err
}
return string(body), nil
}
// SpliceUrl splice two url.
// Examples:
// u1 = "http://domain.com" u2 = "/sso/auth" return http://domain.com/sso/auth
func SpliceUrl(u1 string, u2 string) string {
if u1 == "" {
return u2
}
if u2 == "" {
return u1
}
if strings.HasPrefix(u2, "http") {
return u2
}
return u1 + u2
}
func HasUrl(urls []string, url string) bool {
for _, s := range urls {
if MatchUrl(s, url) {
return true
}
}
return false
}
func MatchUrl(pattern string, url string) bool {
if pattern == "*" {
return true
}
return pattern == url
}
func IsValidUrl(u1 string) bool {
_, err := url.ParseRequestURI(u1)
if err != nil {
return false
}
u, err := url.Parse(u1)
if err != nil || u.Scheme == "" || u.Host == "" {
return false
}
// check if the URL has a valid scheme (http or https)
if u.Scheme != "http" && u.Scheme != "https" {
return false
}
return true
}
// AddQueryMap add map param for the path.
func AddQueryMap(path string, paramMap map[string]string) string {
queryString := MapToQuery(paramMap)
return AddQueryValue(path, queryString)
}
func AddQueryValue(path string, queryString string) string {
index := strings.LastIndex(path, "?")
// if the path is not included "?"
if index == -1 {
return path + "?" + queryString
}
// if the last is "?"
if index == len(path)-1 {
return path + queryString
}
// if "?" inside path, the last is not "&" and queryString's first string is not "&"
if index < len(path)-1 {
if strings.LastIndex(path, "&") != len(path)-1 && strings.Index(path, "&") != 0 {
return path + "&" + queryString
} else {
return path + queryString
}
}
return path
}
// AddQuery add query param for the path.
func AddQuery(path string, key string, value string) string {
queryString := key + "=" + value
return AddQueryValue(path, queryString)
}
// MapToQuery convert map to k=v array, and use "&" to join.
func MapToQuery(paramMap map[string]string) string {
var queryString []string
for k, v := range paramMap {
queryString = append(queryString, k+"="+v)
}
query := strings.Join(queryString, "&")
return query
}
func Encode(u string) string {
return url.QueryEscape(u)
}

144
util/url_test.go Normal file
View File

@@ -0,0 +1,144 @@
package util
import (
"log"
"testing"
)
func TestSpliceUrl(t *testing.T) {
type args struct {
u1 string
u2 string
}
tests := []struct {
name string
args args
want string
}{
{
name: "success1",
args: args{
u1: "http://domain.com",
u2: "/sso/auth",
},
want: "http://domain.com/sso/auth",
},
{
name: "success2",
args: args{
u1: "",
u2: "http://domain.com/sso/auth",
},
want: "http://domain.com/sso/auth",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := SpliceUrl(tt.args.u1, tt.args.u2); got != tt.want {
t.Errorf("SpliceUrl() = %v, want %v", got, tt.want)
}
})
}
}
func TestHasUrl(t *testing.T) {
strings := []string{"1", "2", "3"}
hasUrl := HasUrl(strings, "2")
if !hasUrl {
t.Errorf("HasUrl() = %v, want %v", false, true)
}
}
func TestMatchUrl(t *testing.T) {
all := "*"
if !MatchUrl(all, "123") {
t.Errorf("HasUrl() = %v, want %v", false, true)
}
}
func TestIsValidUrl(t *testing.T) {
type args struct {
u1 string
}
tests := []struct {
name string
args args
want bool
}{
{
name: "error url1",
args: args{u1: "htp:23/asd"},
want: false,
},
{
name: "validated",
args: args{u1: "http://123.com:90//ac"},
want: true,
},
{
name: "error url2",
args: args{u1: "http:/123.com:90//ac"},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsValidUrl(tt.args.u1); got != tt.want {
t.Errorf("IsValidUrl() = %v, want %v", got, tt.want)
}
})
}
}
func TestAddQueryMap(t *testing.T) {
m := make(map[string]string)
m["1"] = "2"
m["2"] = "3"
if got := AddQueryMap("/sso/auth", m); got != "/sso/auth?1=2&2=3" && got != "/sso/auth?2=3&1=2" {
t.Errorf("AddQueryMap() = %v, want %v", got, "/sso/auth?1=2&2=3")
}
if got := AddQueryMap("/sso/auth?", m); got != "/sso/auth?1=2&2=3" && got != "/sso/auth?2=3&1=2" {
t.Errorf("AddQueryMap() = %v, want %v", got, "/sso/auth?1=2&2=3")
}
}
func TestAddQuery(t *testing.T) {
if got := AddQuery("/sso", "1", "2"); got != "/sso?1=2" {
t.Errorf("AddQuery() = %v, want %v", got, "/sso?1=2")
}
if got := AddQuery("/sso?", "1", "2"); got != "/sso?1=2" {
t.Errorf("AddQuery() = %v, want %v", got, "/sso?1=2")
}
}
func TestMapToQuery(t *testing.T) {
m := make(map[string]string)
m["1"] = "2"
query := MapToQuery(m)
if query != "1=2" {
t.Errorf("MapToQuery() = %v, want %v", query, "1=2")
}
}
func TestEncode(t *testing.T) {
encode := Encode("abc123==123")
log.Print("Encode(\"abc123==123\") = " + encode)
}
func TestAddQueryValue(t *testing.T) {
if got := AddQueryValue("/sso/auth?back=http://123.com/login", "ticket=23324"); got != "/sso/auth?back=http://123.com/login&ticket=23324" {
t.Errorf("AddQueryValue() = %v, want %v", got, "/sso/auth?back=http://123.com/login&ticket=23324")
}
if got := AddQueryValue("/sso/auth?back=http://123.com/login?", "ticket=23324"); got != "/sso/auth?back=http://123.com/login?ticket=23324" {
t.Errorf("AddQueryValue() = %v, want %v", got, "/sso/auth?back=http://123.com/login?ticket=23324")
}
if got := AddQueryValue("/sso/auth?back=http://123.com/login?ticket=123", "redirect=23324"); got != "/sso/auth?back=http://123.com/login?ticket=123&redirect=23324" {
t.Errorf("AddQueryValue() = %v, want %v", got, "/sso/auth?back=http://123.com/login?ticket=123&redirect=23324")
}
if got := AddQueryValue("/sso/auth?back=http://123.com/login?ticket=123&", "redirect=23324"); got != "/sso/auth?back=http://123.com/login?ticket=123&redirect=23324" {
t.Errorf("AddQueryValue() = %v, want %v", got, "/sso/auth?back=http://123.com/login?ticket=123&redirect=23324")
}
}

View File

@@ -26,3 +26,12 @@ func InterfaceToBytes(data interface{}) ([]byte, error) {
} }
return nil, fmt.Errorf("unable to convert %T to []byte", data) return nil, fmt.Errorf("unable to convert %T to []byte", data)
} }
// AppendStr do not add repeated str.
// If old slice has newStr, return directly, else append
func AppendStr(old []string, newStr string) []string {
if HasStr(old, newStr) {
return old
}
return append(old, newStr)
}

22
util/util_test.go Normal file
View File

@@ -0,0 +1,22 @@
package util
import (
"testing"
)
func TestAppendStr(t *testing.T) {
strings := []string{"1", "2"}
str := AppendStr(strings, "1")
str = AppendStr(str, "3")
for i, s := range strings {
if s == "1" && i == 2 {
t.Errorf("AppendStr() = %v, want %v", str, []string{"1", "2"})
}
}
for i, s := range strings {
if i == 2 && s != "3" {
t.Errorf("AppendStr() = %v, want %v", str, []string{"1", "2", "3"})
}
}
}