mirror of
https://github.com/weloe/token-go.git
synced 2025-10-05 23:46:52 +08:00
502 lines
14 KiB
Go
502 lines
14 KiB
Go
package sso
|
|
|
|
import (
	"crypto/subtle"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"strconv"
	"strings"
	"time"

	"github.com/weloe/token-go/constant"
	"github.com/weloe/token-go/ctx"
	"github.com/weloe/token-go/model"
	"github.com/weloe/token-go/util"
)
|
|
|
|
/**
|
|
=========internal api
|
|
*/
|
|
|
|
// checkRequest check request param timestamp,nonce,sign.
|
|
func (s *SsoEnforcer) checkRequest(request ctx.Request) error {
|
|
timestamp := request.Query(s.paramName.TimeStamp)
|
|
nonce := request.Query(s.paramName.Nonce)
|
|
sign := request.Query(s.paramName.Sign)
|
|
err := s.checkTimeStamp(timestamp)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if s.signConfig.IsCheckNonce {
|
|
err = s.checkNonce(nonce)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
err = s.checkSign(timestamp, nonce, sign)
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CreateTicket create ticket by account-id.
|
|
func (s *SsoEnforcer) CreateTicket(loginId string, client string) (string, error) {
|
|
// create random string ticket
|
|
ticket, err := util.GenerateRandomString64()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
// save ticket-id+client
|
|
err = s.saveTicket(ticket, loginId, client)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
// save id-ticket
|
|
err = s.saveTicketIndex(ticket, loginId)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return ticket, nil
|
|
}
|
|
|
|
// GetLoginId get loginId by ticket.
|
|
func (s *SsoEnforcer) GetLoginId(ticket string) string {
|
|
if ticket == "" {
|
|
return ""
|
|
}
|
|
loginId := s.getLoginIdByTicket(ticket)
|
|
if loginId != "" && strings.Contains(loginId, ",") {
|
|
split := strings.Split(loginId, ",")
|
|
loginId = split[0]
|
|
}
|
|
return loginId
|
|
}
|
|
|
|
func (s *SsoEnforcer) getLoginIdByTicket(ticket string) string {
|
|
loginId := s.enforcer.GetAdapter().GetStr(s.spliceTicketSaveKey(ticket))
|
|
return loginId
|
|
}
|
|
|
|
// GetTicket get ticket by loginId.
|
|
func (s *SsoEnforcer) GetTicket(loginId string) string {
|
|
if loginId == "" {
|
|
return ""
|
|
}
|
|
return s.enforcer.GetAdapter().GetStr(s.spliceTicketIndexKey(loginId))
|
|
}
|
|
|
|
// CheckTicket use config.Client to check ticket,return loginId.
|
|
func (s *SsoEnforcer) CheckTicket(ticket string) (string, error) {
|
|
return s.CheckTicketByClient(ticket, s.config.Client)
|
|
}
|
|
|
|
// CheckTicketByClient check ticket by pointing client,return loginId.
|
|
func (s *SsoEnforcer) CheckTicketByClient(ticket string, client string) (string, error) {
|
|
id := s.getLoginIdByTicket(ticket)
|
|
if id == "" {
|
|
return "", nil
|
|
}
|
|
|
|
// get client from id
|
|
var ticketClient string
|
|
if strings.Contains(id, ",") {
|
|
split := strings.Split(id, ",")
|
|
id = split[0]
|
|
ticketClient = split[1]
|
|
}
|
|
|
|
if client != "" && client != ticketClient {
|
|
return "", fmt.Errorf("the ticket does not belong to the client, client: %v, ticket: %v", client, ticket)
|
|
}
|
|
err := s.deleteTicket(ticket)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
err = s.deleteTicketIndex(id)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return id, nil
|
|
}
|
|
|
|
// CheckRedirectUrl check redirectUrl.
|
|
func (s *SsoEnforcer) CheckRedirectUrl(url string) error {
|
|
if !util.IsValidUrl(url) {
|
|
return fmt.Errorf("invalid redirect url: %v", url)
|
|
}
|
|
index := strings.Index(url, "?")
|
|
if index != -1 {
|
|
url = url[0:index]
|
|
}
|
|
allowUrls := strings.Split(s.GetAllowUrl(), ",")
|
|
|
|
if !util.HasUrl(allowUrls, url) {
|
|
return fmt.Errorf("illegal redirect url: %v", url)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RegisterSloCallbackUrl register the URL of the single logout callback for the account id.
|
|
func (s *SsoEnforcer) RegisterSloCallbackUrl(loginId string, sloCallbackUrl string) error {
|
|
if loginId == "" || sloCallbackUrl == "" {
|
|
return nil
|
|
}
|
|
session := s.enforcer.GetSession(loginId)
|
|
// splice session id
|
|
sessionId := s.enforcer.GetTokenConfig().TokenName + ":" + s.enforcer.GetType() + ":session:" + loginId
|
|
if session == nil {
|
|
session = model.NewSession(sessionId, "account-session", loginId)
|
|
}
|
|
value := session.Get(constant.SLO_CALLBACK_SET_KEY)
|
|
|
|
var v []string
|
|
if value != nil {
|
|
sv, ok := value.([]string)
|
|
if !ok {
|
|
return errors.New("session SLO_CALLBACK_SET_KEY_ data convert into []string failed")
|
|
}
|
|
v = sv
|
|
}
|
|
v = util.AppendStr(v, sloCallbackUrl)
|
|
|
|
session.Set(sessionId, v)
|
|
// update session
|
|
err := s.enforcer.UpdateSession(loginId, session)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ssoSignOutById single sign-out of the specified account.
|
|
// Use loginId to get single sign-out urls from session.
|
|
func (s *SsoEnforcer) ssoSignOutById(loginId string) error {
|
|
// if loginId is not logged, return error
|
|
session := s.enforcer.GetSession(loginId)
|
|
if session == nil {
|
|
return errors.New("this loginId is not logged in")
|
|
}
|
|
value := session.Get(constant.SLO_CALLBACK_SET_KEY)
|
|
if value == nil {
|
|
return nil
|
|
}
|
|
urls, ok := value.([]string)
|
|
if !ok {
|
|
return errors.New("convert into []string failed")
|
|
}
|
|
// range urls to make client logout
|
|
for _, url := range urls {
|
|
// join url
|
|
newUrl, err := s.joinLoginIdAndSign(url, loginId)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// sent http
|
|
_, err = s.config.SendHttp(newUrl)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// server logout
|
|
return s.enforcer.LogoutById(loginId)
|
|
}
|
|
|
|
// joinLoginIdAndSign splice the loginId to the url, and stitching parameters such as sign.
|
|
func (s *SsoEnforcer) joinLoginIdAndSign(url string, id string) (string, error) {
|
|
nonce, err := util.GenerateRandomString32()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
|
|
sign, err := s.createSign(timestamp, nonce)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
str := url + "?" + s.paramName.LoginId + "=" + id + "&" + s.paramName.TimeStamp + "=" + timestamp + "&" + s.paramName.Nonce + "=" + nonce + "&" + s.paramName.Sign + "=" + sign
|
|
return str, nil
|
|
}
|
|
|
|
// buildServerAuthUrl SSO-Client build SSO-Server single sign-on url.
|
|
func (s *SsoEnforcer) buildServerAuthUrl(clientLoginUrl string, back string) (string, error) {
|
|
if clientLoginUrl == "" {
|
|
return "", errors.New("arg[0] clientLoginUrl can not be nil")
|
|
}
|
|
|
|
// get server auth url
|
|
authUrl := s.config.SpliceAuthUrl()
|
|
|
|
client := s.config.Client
|
|
|
|
if client != "" {
|
|
authUrl = util.AddQuery(authUrl, s.paramName.Client, client)
|
|
}
|
|
|
|
// splice back url
|
|
if back != "" {
|
|
back = util.Encode(back)
|
|
clientLoginUrl = util.AddQuery(clientLoginUrl, s.paramName.Back, back)
|
|
}
|
|
|
|
return util.AddQuery(authUrl, s.paramName.Redirect, clientLoginUrl), nil
|
|
}
|
|
|
|
// buildRedirectUrl the server gives the redirectUrl of the ticket to the client.
|
|
// Check redirect url, delete old ticket of loginId and create new ticket, then return url with new ticket.
|
|
func (s *SsoEnforcer) buildRedirectUrl(loginId string, client string, redirect string) (string, error) {
|
|
// check redirect url
|
|
err := s.CheckRedirectUrl(redirect)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
// delete old ticket
|
|
err = s.deleteTicket(s.GetTicket(loginId))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
// create new ticket
|
|
ticket, err := s.CreateTicket(loginId, client)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// return redirect + "?" + s.paramName.Ticket + "=" + ticket, nil
|
|
return util.AddQuery(s.encodeBackParam(redirect), s.paramName.Ticket, ticket), nil
|
|
}
|
|
|
|
// encodeBackParam find back param from url, and encode back param.
|
|
func (s *SsoEnforcer) encodeBackParam(url string) string {
|
|
// get back location
|
|
index := strings.Index(url, "?"+s.paramName.Back+"=")
|
|
if index == -1 {
|
|
index = strings.Index(url, "&"+s.paramName.Back+"=")
|
|
if index == -1 {
|
|
return url
|
|
}
|
|
}
|
|
|
|
// encode
|
|
length := len(s.paramName.Back) + 2
|
|
back := url[index+length:]
|
|
back = util.Encode(back)
|
|
|
|
// update back
|
|
url = url[:index+length] + back
|
|
return url
|
|
}
|
|
|
|
// buildCheckTicketUrl build to check ticket.
|
|
func (s *SsoEnforcer) buildCheckTicketUrl(ticket string, ssoLogoutCallUrl string) (string, error) {
|
|
if ticket == "" {
|
|
return "", errors.New("buildCheckTicketUrl() ticket can not be nil")
|
|
}
|
|
checkTicketUrl := s.config.SpliceCheckTicketUrl()
|
|
client := s.config.Client
|
|
paramMap := make(map[string]string)
|
|
if client != "" {
|
|
paramMap[s.paramName.Client] = client
|
|
}
|
|
paramMap[s.paramName.Ticket] = ticket
|
|
if ssoLogoutCallUrl != "" {
|
|
paramMap[s.paramName.SsoLogoutCall] = ssoLogoutCallUrl
|
|
}
|
|
|
|
return util.AddQueryMap(checkTicketUrl, paramMap), nil
|
|
}
|
|
|
|
// buildSloUrl build single-logout url.
|
|
func (s *SsoEnforcer) buildSloUrl(loginId string) (string, error) {
|
|
sloUrl := s.config.SpliceSloUrl()
|
|
url, err := s.joinLoginIdAndSign(sloUrl, loginId)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return url, nil
|
|
}
|
|
|
|
// buildGetDataUrl build getData url with sign,timestamp,nonce.
|
|
func (s *SsoEnforcer) buildGetDataUrl(paramMap map[string]string) (string, error) {
|
|
getDataUrl := s.config.SpliceGetDataUrl()
|
|
return s.buildCustomPathUrl(getDataUrl, paramMap)
|
|
}
|
|
|
|
// buildCustomPathUrl add paramMap to path.
|
|
func (s *SsoEnforcer) buildCustomPathUrl(path string, paramMap map[string]string) (string, error) {
|
|
u := path
|
|
|
|
if !strings.HasPrefix(u, "http") {
|
|
serverUrl := s.config.ServerUrl
|
|
if serverUrl == "" {
|
|
return "", errors.New("please set sso serverUrl")
|
|
}
|
|
u = util.SpliceUrl(serverUrl, path)
|
|
}
|
|
// sign map
|
|
timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
|
|
paramMap[s.paramName.TimeStamp] = timestamp
|
|
nonce, err := util.GenerateRandomString32()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
paramMap[s.paramName.Nonce] = nonce
|
|
// create sign
|
|
sign, err := s.createSign(timestamp, nonce)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
paramMap[s.paramName.Sign] = sign
|
|
finalUrl := util.AddQueryMap(u, paramMap)
|
|
return finalUrl, nil
|
|
}
|
|
|
|
// request send http request and use json.Unmarshal to converted to *model.Result.
|
|
func (s *SsoEnforcer) request(url string) (*model.Result, error) {
|
|
resp, err := s.config.SendHttp(url)
|
|
|
|
log.Printf("http request response: %s", resp)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result := &model.Result{}
|
|
err = json.Unmarshal([]byte(resp), result)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
func (s *SsoEnforcer) GetAllowUrl() string {
|
|
return s.config.AllowUrl
|
|
}
|
|
|
|
// saveTicket save ticket-id+client.
|
|
func (s *SsoEnforcer) saveTicket(ticket string, loginId string, client string) error {
|
|
value := loginId
|
|
if client != "" {
|
|
value += "," + client
|
|
}
|
|
ticketTimeout := s.config.TicketTimeout
|
|
return s.enforcer.GetAdapter().SetStr(s.spliceTicketSaveKey(ticket), value, ticketTimeout)
|
|
}
|
|
|
|
// delete ticket - id,client
|
|
func (s *SsoEnforcer) deleteTicket(ticket string) error {
|
|
if ticket == "" {
|
|
return nil
|
|
}
|
|
return s.enforcer.GetAdapter().DeleteStr(s.spliceTicketSaveKey(ticket))
|
|
}
|
|
|
|
// spliceTicketSaveKey splice ticket-id,client key.
|
|
func (s *SsoEnforcer) spliceTicketSaveKey(ticket string) string {
|
|
return s.enforcer.GetTokenConfig().TokenName + ":ticket:" + ticket
|
|
}
|
|
|
|
// saveTicketIndex save id-ticket.
|
|
func (s *SsoEnforcer) saveTicketIndex(ticket string, id string) error {
|
|
ticketTimeout := s.config.TicketTimeout
|
|
return s.enforcer.GetAdapter().SetStr(s.spliceTicketIndexKey(id), ticket, ticketTimeout)
|
|
}
|
|
|
|
func (s *SsoEnforcer) deleteTicketIndex(id string) error {
|
|
if id == "" {
|
|
return nil
|
|
}
|
|
return s.enforcer.GetAdapter().DeleteStr(s.spliceTicketIndexKey(id))
|
|
}
|
|
|
|
// spliceTicketIndexKey splice id-ticket key.
|
|
func (s *SsoEnforcer) spliceTicketIndexKey(id string) string {
|
|
return s.enforcer.GetTokenConfig().TokenName + ":id-ticket:" + id
|
|
}
|
|
|
|
// checkTimeStamp determine whether the gap between the timestamp and the current timestamp is within the allowable range.
|
|
func (s *SsoEnforcer) checkTimeStamp(timestamp string) error {
|
|
parseInt, err := strconv.ParseInt(timestamp, 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !s.isValidTimeStamp(parseInt) {
|
|
return errors.New("timestamp is out of allowed range: " + timestamp)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// isValidTimeStamp determine whether the gap between the timestamp and the current timestamp is within the allowable range.
|
|
func (s *SsoEnforcer) isValidTimeStamp(timestamp int64) bool {
|
|
allowDisparity := s.signConfig.TimeStampDisparity
|
|
nowDisparity := time.Now().UnixMilli() - timestamp
|
|
|
|
return allowDisparity == 1 || nowDisparity <= allowDisparity
|
|
}
|
|
|
|
// checkNonce the same nonce can only be verified once, cannot be used again for a period of time after use
|
|
func (s *SsoEnforcer) checkNonce(nonce string) error {
|
|
if nonce == "" {
|
|
return errors.New("nonce is nil")
|
|
}
|
|
// if nonce exists in adapter
|
|
if !s.isValidNonce(nonce) {
|
|
return errors.New("the nonce has been used: " + nonce)
|
|
}
|
|
// set nonce after nonce
|
|
err := s.enforcer.GetAdapter().SetStr(s.spliceNonceSaveKey(nonce), nonce, s.signConfig.GetSaveNonceExpire()*2+2)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// isValidNonce determine random string, if not exist in adapter, return true.
|
|
func (s *SsoEnforcer) isValidNonce(nonce string) bool {
|
|
if nonce == "" {
|
|
return false
|
|
}
|
|
key := s.spliceNonceSaveKey(nonce)
|
|
// if not exist in adapter, return true.
|
|
return s.enforcer.GetAdapter().GetStr(key) == ""
|
|
}
|
|
|
|
func (s *SsoEnforcer) checkSign(timestamp string, nonce string, sign string) error {
|
|
valid, err := s.isValidSign(timestamp, nonce, sign)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !valid {
|
|
return errors.New("invalid sign: " + sign)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// isValidSign use timestamp,nonce,sign to createSign to compare.
|
|
func (s *SsoEnforcer) isValidSign(timestamp string, nonce string, sign string) (bool, error) {
|
|
recreateSign, err := s.createSign(timestamp, nonce)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return recreateSign == sign, nil
|
|
}
|
|
|
|
// createSign use util.MD5() to generate str.
|
|
func (s *SsoEnforcer) createSign(timestamp string, nonce string) (string, error) {
|
|
secretKey := s.signConfig.SecretKey
|
|
if secretKey == "" {
|
|
return "", errors.New("please check SignConfig.SecretKey, SecretKey can not be nil")
|
|
}
|
|
str := s.paramName.Nonce + "=" + nonce + "&" + s.paramName.TimeStamp + "=" + timestamp + "&" + s.paramName.SecretKet + "=" + secretKey
|
|
|
|
return util.MD5(str), nil
|
|
}
|
|
|
|
// spliceNonceSaveKey splice nonce store key.
|
|
func (s *SsoEnforcer) spliceNonceSaveKey(nonce string) string {
|
|
return s.enforcer.GetTokenConfig().TokenName + ":sign:nonce:" + nonce
|
|
}
|