mirror of
				https://github.com/weloe/token-go.git
				synced 2025-10-31 19:02:41 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			495 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			495 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package sso
 | |
| 
 | |
import (
	"crypto/subtle"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"strconv"
	"strings"
	"time"

	"github.com/weloe/token-go/constant"
	"github.com/weloe/token-go/ctx"
	"github.com/weloe/token-go/model"
	"github.com/weloe/token-go/util"
)
 | |
| 
 | |
| /**
 | |
| =========internal api
 | |
| */
 | |
| 
 | |
| // checkRequest check request param timestamp,nonce,sign.
 | |
| func (s *SsoEnforcer) checkRequest(request ctx.Request) error {
 | |
| 	timestamp := request.Query(s.paramName.TimeStamp)
 | |
| 	nonce := request.Query(s.paramName.Nonce)
 | |
| 	sign := request.Query(s.paramName.Sign)
 | |
| 	err := s.checkTimeStamp(timestamp)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	if s.signConfig.IsCheckNonce {
 | |
| 		err = s.checkNonce(nonce)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 	err = s.checkSign(timestamp, nonce, sign)
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // CreateTicket create ticket by account-id.
 | |
| func (s *SsoEnforcer) CreateTicket(loginId string, client string) (string, error) {
 | |
| 	// create random string ticket
 | |
| 	ticket, err := util.GenerateRandomString64()
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	// save ticket-id+client
 | |
| 	err = s.saveTicket(ticket, loginId, client)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	// save id-ticket
 | |
| 	err = s.saveTicketIndex(ticket, loginId)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	return ticket, nil
 | |
| }
 | |
| 
 | |
| // GetLoginId get loginId by ticket.
 | |
| func (s *SsoEnforcer) GetLoginId(ticket string) string {
 | |
| 	if ticket == "" {
 | |
| 		return ""
 | |
| 	}
 | |
| 	loginId := s.enforcer.GetAdapter().GetStr(s.spliceTicketSaveKey(ticket))
 | |
| 	if loginId != "" && strings.Contains(loginId, ",") {
 | |
| 		split := strings.Split(loginId, ",")
 | |
| 		loginId = split[0]
 | |
| 	}
 | |
| 	return loginId
 | |
| }
 | |
| 
 | |
| // GetTicket get ticket by loginId.
 | |
| func (s *SsoEnforcer) GetTicket(loginId string) string {
 | |
| 	if loginId == "" {
 | |
| 		return ""
 | |
| 	}
 | |
| 	return s.enforcer.GetAdapter().GetStr(s.spliceTicketIndexKey(loginId))
 | |
| }
 | |
| 
 | |
| // CheckTicket use config.Client to check ticket,return loginId.
 | |
| func (s *SsoEnforcer) CheckTicket(ticket string) (string, error) {
 | |
| 	return s.CheckTicketByClient(ticket, s.config.Client)
 | |
| }
 | |
| 
 | |
| // CheckTicketByClient check ticket by pointing client,return loginId.
 | |
| func (s *SsoEnforcer) CheckTicketByClient(ticket string, client string) (string, error) {
 | |
| 	id := s.enforcer.GetAdapter().GetStr(s.spliceTicketSaveKey(ticket))
 | |
| 	if id == "" {
 | |
| 		return "", nil
 | |
| 	}
 | |
| 
 | |
| 	var ticketClient string
 | |
| 	if strings.Contains(id, ",") {
 | |
| 		split := strings.Split(id, ",")
 | |
| 		id = split[0]
 | |
| 		ticketClient = split[1]
 | |
| 	}
 | |
| 
 | |
| 	if client != "" && client != ticketClient {
 | |
| 		return "", fmt.Errorf("the ticket does not belong to the client, client: %v, ticket: %v", client, ticket)
 | |
| 	}
 | |
| 	err := s.deleteTicket(ticket)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	err = s.deleteTicketIndex(id)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	return id, nil
 | |
| }
 | |
| 
 | |
| // CheckRedirectUrl check redirectUrl.
 | |
| func (s *SsoEnforcer) CheckRedirectUrl(url string) error {
 | |
| 	if !util.IsValidUrl(url) {
 | |
| 		return fmt.Errorf("invalid redirect url: %v", url)
 | |
| 	}
 | |
| 	index := strings.Index(url, "?")
 | |
| 	if index != -1 {
 | |
| 		url = url[0:index]
 | |
| 	}
 | |
| 	allowUrls := strings.Split(s.GetAllowUrl(), ",")
 | |
| 
 | |
| 	if !util.HasUrl(allowUrls, url) {
 | |
| 		return fmt.Errorf("illegal redirect url: %v", url)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // RegisterSloCallbackUrl register the URL of the single logout callback for the account id.
 | |
| func (s *SsoEnforcer) RegisterSloCallbackUrl(loginId string, sloCallbackUrl string) error {
 | |
| 	if loginId == "" || sloCallbackUrl == "" {
 | |
| 		return nil
 | |
| 	}
 | |
| 	session := s.enforcer.GetSession(loginId)
 | |
| 	// splice session id
 | |
| 	sessionId := s.enforcer.GetTokenConfig().TokenName + ":" + s.enforcer.GetType() + ":session:" + loginId
 | |
| 	if session == nil {
 | |
| 		session = model.NewSession(sessionId, "account-session", loginId)
 | |
| 	}
 | |
| 	value := session.Get(constant.SLO_CALLBACK_SET_KEY)
 | |
| 
 | |
| 	var v []string
 | |
| 	if value != nil {
 | |
| 		sv, ok := value.([]string)
 | |
| 		v = sv
 | |
| 		if !ok {
 | |
| 			return errors.New("session SLO_CALLBACK_SET_KEY_ data convert into []string failed")
 | |
| 		}
 | |
| 	}
 | |
| 	v = util.AppendStr(v, sloCallbackUrl)
 | |
| 
 | |
| 	session.Set(sessionId, v)
 | |
| 	// update session
 | |
| 	err := s.enforcer.UpdateSession(loginId, session)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // ssoSignOutById single sign-out of the specified account.
 | |
| // Use loginId to get single sign-out urls from session.
 | |
| func (s *SsoEnforcer) ssoSignOutById(loginId string) error {
 | |
| 	// if loginId is not logged, return error
 | |
| 	session := s.enforcer.GetSession(loginId)
 | |
| 	if session == nil {
 | |
| 		return errors.New("this loginId is not logged in")
 | |
| 	}
 | |
| 	value := session.Get(constant.SLO_CALLBACK_SET_KEY)
 | |
| 	if value == nil {
 | |
| 		return nil
 | |
| 	}
 | |
| 	urls, ok := value.([]string)
 | |
| 	if !ok {
 | |
| 		return errors.New("convert into []string failed")
 | |
| 	}
 | |
| 	// range urls to make client logout
 | |
| 	for _, url := range urls {
 | |
| 		// join url
 | |
| 		newUrl, err := s.joinLoginIdAndSign(url, loginId)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 		// sent http
 | |
| 		_, err = s.config.SendHttp(newUrl)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// server logout
 | |
| 	return s.enforcer.LogoutById(loginId)
 | |
| }
 | |
| 
 | |
| // joinLoginIdAndSign splice the loginId to the url, and stitching parameters such as sign.
 | |
| func (s *SsoEnforcer) joinLoginIdAndSign(url string, id string) (string, error) {
 | |
| 	nonce, err := util.GenerateRandomString32()
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
 | |
| 	sign, err := s.createSign(timestamp, nonce)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	str := url + "?" + s.paramName.LoginId + "=" + id + "&" + s.paramName.TimeStamp + "=" + timestamp + "&" + s.paramName.Nonce + "=" + nonce + "&" + s.paramName.Sign + "=" + sign
 | |
| 	return str, nil
 | |
| }
 | |
| 
 | |
| // buildServerAuthUrl SSO-Client build SSO-Server single sign-on url.
 | |
| func (s *SsoEnforcer) buildServerAuthUrl(clientLoginUrl string, back string) (string, error) {
 | |
| 	if clientLoginUrl == "" {
 | |
| 		return "", errors.New("arg[0] clientLoginUrl can not be nil")
 | |
| 	}
 | |
| 
 | |
| 	// get server auth url
 | |
| 	authUrl := s.config.SpliceAuthUrl()
 | |
| 
 | |
| 	client := s.config.Client
 | |
| 
 | |
| 	if client != "" {
 | |
| 		authUrl = util.AddQuery(authUrl, s.paramName.Client, client)
 | |
| 	}
 | |
| 
 | |
| 	// splice back url
 | |
| 	if back != "" {
 | |
| 		back = util.Encode(back)
 | |
| 		clientLoginUrl = util.AddQuery(clientLoginUrl, s.paramName.Back, back)
 | |
| 	}
 | |
| 
 | |
| 	return util.AddQuery(authUrl, s.paramName.Redirect, clientLoginUrl), nil
 | |
| }
 | |
| 
 | |
| // buildRedirectUrl the server gives the redirectUrl of the ticket to the client.
 | |
| // Check redirect url, delete old ticket of loginId and create new ticket, then return url with new ticket.
 | |
| func (s *SsoEnforcer) buildRedirectUrl(loginId string, client string, redirect string) (string, error) {
 | |
| 	// check redirect url
 | |
| 	err := s.CheckRedirectUrl(redirect)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	// delete old ticket
 | |
| 	err = s.deleteTicket(s.GetTicket(loginId))
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	// create new ticket
 | |
| 	ticket, err := s.CreateTicket(loginId, client)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	// return redirect + "?" + s.paramName.Ticket + "=" + ticket, nil
 | |
| 	return util.AddQuery(s.encodeBackParam(redirect), s.paramName.Ticket, ticket), nil
 | |
| }
 | |
| 
 | |
| // encodeBackParam find back param from url, and encode back param.
 | |
| func (s *SsoEnforcer) encodeBackParam(url string) string {
 | |
| 	// get back location
 | |
| 	index := strings.Index(url, "?"+s.paramName.Back+"=")
 | |
| 	if index == -1 {
 | |
| 		index = strings.Index(url, "&"+s.paramName.Back+"=")
 | |
| 		if index == -1 {
 | |
| 			return url
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// encode
 | |
| 	length := len(s.paramName.Back) + 2
 | |
| 	back := url[index+length:]
 | |
| 	back = util.Encode(back)
 | |
| 
 | |
| 	// update back
 | |
| 	url = url[:index+length] + back
 | |
| 	return url
 | |
| }
 | |
| 
 | |
| // buildCheckTicketUrl build to check ticket.
 | |
| func (s *SsoEnforcer) buildCheckTicketUrl(ticket string, ssoLogoutCallUrl string) (string, error) {
 | |
| 	if ticket == "" {
 | |
| 		return "", errors.New("buildCheckTicketUrl() ticket can not be nil")
 | |
| 	}
 | |
| 	checkTicketUrl := s.config.SpliceCheckTicketUrl()
 | |
| 	client := s.config.Client
 | |
| 	paramMap := make(map[string]string)
 | |
| 	if client != "" {
 | |
| 		paramMap[s.paramName.Client] = client
 | |
| 	}
 | |
| 	paramMap[s.paramName.Ticket] = ticket
 | |
| 	if ssoLogoutCallUrl != "" {
 | |
| 		paramMap[s.paramName.SsoLogoutCall] = ssoLogoutCallUrl
 | |
| 	}
 | |
| 
 | |
| 	return util.AddQueryMap(checkTicketUrl, paramMap), nil
 | |
| }
 | |
| 
 | |
| // buildSloUrl build single-logout url.
 | |
| func (s *SsoEnforcer) buildSloUrl(loginId string) (string, error) {
 | |
| 	sloUrl := s.config.SpliceSloUrl()
 | |
| 	url, err := s.joinLoginIdAndSign(sloUrl, loginId)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	return url, nil
 | |
| }
 | |
| 
 | |
| // buildGetDataUrl build getData url with sign,timestamp,nonce.
 | |
| func (s *SsoEnforcer) buildGetDataUrl(paramMap map[string]string) (string, error) {
 | |
| 	getDataUrl := s.config.SpliceGetDataUrl()
 | |
| 	return s.buildCustomPathUrl(getDataUrl, paramMap)
 | |
| }
 | |
| 
 | |
| // buildCustomPathUrl add paramMap to path.
 | |
| func (s *SsoEnforcer) buildCustomPathUrl(path string, paramMap map[string]string) (string, error) {
 | |
| 	u := path
 | |
| 
 | |
| 	if !strings.HasPrefix(u, "http") {
 | |
| 		serverUrl := s.config.ServerUrl
 | |
| 		if serverUrl == "" {
 | |
| 			return "", errors.New("please set sso serverUrl")
 | |
| 		}
 | |
| 		u = util.SpliceUrl(serverUrl, path)
 | |
| 	}
 | |
| 	// sign map
 | |
| 	timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
 | |
| 	paramMap[s.paramName.TimeStamp] = timestamp
 | |
| 	nonce, err := util.GenerateRandomString32()
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	paramMap[s.paramName.Nonce] = nonce
 | |
| 	// create sign
 | |
| 	sign, err := s.createSign(timestamp, nonce)
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 	paramMap[s.paramName.Sign] = sign
 | |
| 	finalUrl := util.AddQueryMap(u, paramMap)
 | |
| 	return finalUrl, nil
 | |
| }
 | |
| 
 | |
| // request send http request and use json.Unmarshal to converted to *model.Result.
 | |
| func (s *SsoEnforcer) request(url string) (*model.Result, error) {
 | |
| 	resp, err := s.config.SendHttp(url)
 | |
| 
 | |
| 	log.Printf("http request response: %s", resp)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	result := &model.Result{}
 | |
| 	err = json.Unmarshal([]byte(resp), result)
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 	return result, nil
 | |
| }
 | |
| 
 | |
| func (s *SsoEnforcer) GetAllowUrl() string {
 | |
| 	return s.config.AllowUrl
 | |
| }
 | |
| 
 | |
| // saveTicket save ticket-id+client.
 | |
| func (s *SsoEnforcer) saveTicket(ticket string, loginId string, client string) error {
 | |
| 	value := loginId
 | |
| 	if client != "" {
 | |
| 		value += "," + client
 | |
| 	}
 | |
| 	ticketTimeout := s.config.TicketTimeout
 | |
| 	return s.enforcer.GetAdapter().SetStr(s.spliceTicketSaveKey(ticket), value, ticketTimeout)
 | |
| }
 | |
| 
 | |
| // saveTicketIndex save id-ticket.
 | |
| func (s *SsoEnforcer) saveTicketIndex(ticket string, id string) error {
 | |
| 	ticketTimeout := s.config.TicketTimeout
 | |
| 	return s.enforcer.GetAdapter().SetStr(s.spliceTicketIndexKey(id), ticket, ticketTimeout)
 | |
| }
 | |
| 
 | |
| // spliceTicketSaveKey splice ticket-id key.
 | |
| func (s *SsoEnforcer) spliceTicketSaveKey(ticket string) string {
 | |
| 	return s.enforcer.GetTokenConfig().TokenName + ":ticket:" + ticket
 | |
| }
 | |
| 
 | |
| // spliceTicketIndexKey splice id-ticket key.
 | |
| func (s *SsoEnforcer) spliceTicketIndexKey(id string) string {
 | |
| 	return s.enforcer.GetTokenConfig().TokenName + ":id-ticket:" + id
 | |
| }
 | |
| 
 | |
| func (s *SsoEnforcer) deleteTicket(ticket string) error {
 | |
| 	if ticket == "" {
 | |
| 		return nil
 | |
| 	}
 | |
| 	return s.enforcer.GetAdapter().DeleteStr(s.spliceTicketSaveKey(ticket))
 | |
| }
 | |
| 
 | |
| func (s *SsoEnforcer) deleteTicketIndex(id string) error {
 | |
| 	if id == "" {
 | |
| 		return nil
 | |
| 	}
 | |
| 	return s.enforcer.GetAdapter().DeleteStr(s.spliceTicketIndexKey(id))
 | |
| }
 | |
| 
 | |
| // checkTimeStamp determine whether the gap between the timestamp and the current timestamp is within the allowable range.
 | |
| func (s *SsoEnforcer) checkTimeStamp(timestamp string) error {
 | |
| 	parseInt, err := strconv.ParseInt(timestamp, 10, 64)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	if !s.isValidTimeStamp(parseInt) {
 | |
| 		return errors.New("timestamp is out of allowed range: " + timestamp)
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // isValidTimeStamp determine whether the gap between the timestamp and the current timestamp is within the allowable range.
 | |
| func (s *SsoEnforcer) isValidTimeStamp(timestamp int64) bool {
 | |
| 	allowDisparity := s.signConfig.TimeStampDisparity
 | |
| 	nowDisparity := time.Now().UnixMilli() - timestamp
 | |
| 
 | |
| 	return allowDisparity == 1 || nowDisparity <= allowDisparity
 | |
| }
 | |
| 
 | |
| // checkNonce the same nonce can only be verified once.
 | |
| func (s *SsoEnforcer) checkNonce(nonce string) error {
 | |
| 	if nonce == "" {
 | |
| 		return errors.New("nonce is nil")
 | |
| 	}
 | |
| 	// if nonce exists in adapter
 | |
| 	if !s.isValidNonce(nonce) {
 | |
| 		return errors.New("the nonce has been used: " + nonce)
 | |
| 	}
 | |
| 	// set nonce after nonce
 | |
| 	err := s.enforcer.GetAdapter().SetStr(s.spliceNonceSaveKey(nonce), nonce, s.signConfig.GetSaveNonceExpire()*2+2)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // isValidNonce determine random string, if not exist in adapter, return true.
 | |
| func (s *SsoEnforcer) isValidNonce(nonce string) bool {
 | |
| 	if nonce == "" {
 | |
| 		return false
 | |
| 	}
 | |
| 	key := s.spliceNonceSaveKey(nonce)
 | |
| 	// if not exist in adapter, return true.
 | |
| 	return s.enforcer.GetAdapter().GetStr(key) == ""
 | |
| }
 | |
| 
 | |
| func (s *SsoEnforcer) checkSign(timestamp string, nonce string, sign string) error {
 | |
| 	valid, err := s.isValidSign(timestamp, nonce, sign)
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	if !valid {
 | |
| 		return errors.New("invalid sign: " + sign)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // isValidSign use timestamp,nonce,sign to createSign to compare.
 | |
| func (s *SsoEnforcer) isValidSign(timestamp string, nonce string, sign string) (bool, error) {
 | |
| 	recreateSign, err := s.createSign(timestamp, nonce)
 | |
| 	if err != nil {
 | |
| 		return false, err
 | |
| 	}
 | |
| 	return recreateSign == sign, nil
 | |
| }
 | |
| 
 | |
| // createSign use util.MD5() to generate str.
 | |
| func (s *SsoEnforcer) createSign(timestamp string, nonce string) (string, error) {
 | |
| 	secretKey := s.signConfig.SecretKey
 | |
| 	if secretKey == "" {
 | |
| 		return "", errors.New("please check SignConfig.SecretKey, SecretKey can not be nil")
 | |
| 	}
 | |
| 	str := s.paramName.Nonce + "=" + nonce + "&" + s.paramName.TimeStamp + "=" + timestamp + "&" + s.paramName.SecretKet + "=" + secretKey
 | |
| 
 | |
| 	return util.MD5(str), nil
 | |
| }
 | |
| 
 | |
| // spliceNonceSaveKey splice nonce store key.
 | |
| func (s *SsoEnforcer) spliceNonceSaveKey(nonce string) string {
 | |
| 	return s.enforcer.GetTokenConfig().TokenName + ":sign:nonce:" + nonce
 | |
| }
 | 
