mirror of
https://github.com/eryajf/chatgpt-dingtalk.git
synced 2025-11-01 03:52:32 +08:00
feat: 支持上传图片到钉钉平台,在图片生成流程中使用钉钉的图片 CDN 能力 (#225)
This commit is contained in:
10
README.md
10
README.md
@@ -225,6 +225,7 @@ $ docker run -itd --name chatgpt -p 8090:8090 \
|
||||
-e SENSITIVE_WORDS="aa,bb" \
|
||||
-e AZURE_ON="false" -e AZURE_API_VERSION="" -e AZURE_RESOURCE_NAME="" \
|
||||
-e AZURE_DEPLOYMENT_NAME="" -e AZURE_OPENAI_TOKEN="" \
|
||||
-e DINGTALK_CREDENTIALS="your_client_id1:secret1,your_client_id2:secret2" \
|
||||
-e HELP="欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/)
|
||||
,觉得不错你可以来波素质三连." \
|
||||
--restart=always dockerproxy.com/eryajf/chatgpt-dingtalk:latest
|
||||
@@ -541,6 +542,15 @@ azure_resource_name: "xxxx"
|
||||
azure_deployment_name: "xxxx"
|
||||
azure_openai_token: "xxxx"
|
||||
|
||||
# 钉钉应用鉴权凭据信息,支持多个应用。通过请求时候鉴权来识别是来自哪个机器人应用的消息
|
||||
# 设置credentials 之后,即具备了访问钉钉平台绝大部分 OpenAPI 的能力;例如上传图片到钉钉平台,提升图片体验,结合 Stream 模式简化服务部署
|
||||
# client_id 对应钉钉平台 AppKey/SuiteKey;client_secret 对应 AppSecret/SuiteSecret
|
||||
# 建议采用 credentials 代替 app_secrets 配置项,以获得钉钉 OpenAPI 访问能力
|
||||
credentials:
|
||||
-
|
||||
client_id: "put-your-client-id-here"
|
||||
client_secret: "put-your-client-secret-here"
|
||||
|
||||
```
|
||||
|
||||
## 常见问题
|
||||
|
||||
@@ -59,3 +59,10 @@ azure_resource_name: "xxxx"
|
||||
azure_deployment_name: "xxxx"
|
||||
azure_openai_token: "xxxx"
|
||||
|
||||
# 钉钉应用鉴权凭据信息,支持多个应用。通过请求时候鉴权来识别是来自哪个机器人应用的消息
|
||||
# 设置credentials 之后,即具备了访问钉钉平台绝大部分 OpenAPI 的能力;例如上传图片到钉钉平台,提升图片体验,结合 Stream 模式简化服务部署
|
||||
# client_id 对应钉钉平台 AppKey/SuiteKey;client_secret 对应 AppSecret/SuiteSecret
|
||||
#credentials:
|
||||
# -
|
||||
# client_id: "put-your-client-id-here"
|
||||
# client_secret: "put-your-client-secret-here"
|
||||
|
||||
@@ -14,6 +14,11 @@ import (
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
type Credential struct {
|
||||
ClientID string `yaml:"client_id"`
|
||||
ClientSecret string `yaml:"client_secret"`
|
||||
}
|
||||
|
||||
// Configuration 项目配置
|
||||
type Configuration struct {
|
||||
// 日志级别,info或者debug
|
||||
@@ -62,6 +67,8 @@ type Configuration struct {
|
||||
AzureResourceName string `yaml:"azure_resource_name"`
|
||||
AzureDeploymentName string `yaml:"azure_deployment_name"`
|
||||
AzureOpenAIToken string `yaml:"azure_openai_token"`
|
||||
// 钉钉应用鉴权凭据
|
||||
Credentials []Credential `yaml:"credentials"`
|
||||
}
|
||||
|
||||
var config *Configuration
|
||||
@@ -190,6 +197,18 @@ func LoadConfig() *Configuration {
|
||||
if azureOpenaiToken != "" {
|
||||
config.AzureOpenAIToken = azureOpenaiToken
|
||||
}
|
||||
credentials := os.Getenv("DINGTALK_CREDENTIALS")
|
||||
if credentials != "" {
|
||||
if config.Credentials == nil {
|
||||
config.Credentials = []Credential{}
|
||||
}
|
||||
for _, idSecret := range strings.Split(credentials, ",") {
|
||||
items := strings.SplitN(idSecret, ":", 2)
|
||||
if len(items) == 2 {
|
||||
config.Credentials = append(config.Credentials, Credential{ClientID: items[0], ClientSecret: items[1]})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ services:
|
||||
AZURE_RESOURCE_NAME: "" # Azure OpenAi API 资源名称,比如 "openai"
|
||||
AZURE_DEPLOYMENT_NAME: "" # Azure OpenAi API 部署名称,比如 "openai"
|
||||
AZURE_OPENAI_TOKEN: "" # Azure token
|
||||
DINGTALK_CREDENTIALS: "" # 钉钉应用访问凭证,比如 "client_id1:secret1,client_id2:secret2"
|
||||
HELP: "欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/),觉得不错你可以来波素质三连." # 帮助信息,放在配置文件,可供自定义
|
||||
volumes:
|
||||
- ./data:/app/data
|
||||
|
||||
10
main.go
10
main.go
@@ -33,6 +33,14 @@ func Start() {
|
||||
return
|
||||
}
|
||||
// 先校验回调是否合法
|
||||
clientId, checkOk := public.CheckRequestWithCredentials(c.GetHeader("timestamp"), c.GetHeader("sign"))
|
||||
if !checkOk {
|
||||
logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!")
|
||||
return
|
||||
}
|
||||
// 通过 context 传递 OAuth ClientID,用于后续流程中调用钉钉OpenAPI
|
||||
c.Set(public.DingTalkClientIdKeyName, clientId)
|
||||
// 为了兼容存量老用户,暂时保留 public.CheckRequest 方法,将来升级到 Stream 模式后,建议去除该方法,采用上面的 CheckRequestWithCredentials
|
||||
if !public.CheckRequest(c.GetHeader("timestamp"), c.GetHeader("sign")) && msgObj.SenderStaffId != "" {
|
||||
logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!")
|
||||
return
|
||||
@@ -114,7 +122,7 @@ func Start() {
|
||||
// 除去帮助之外的逻辑分流在这里处理
|
||||
switch {
|
||||
case strings.HasPrefix(msgObj.Text.Content, "#图片"):
|
||||
err := process.ImageGenerate(&msgObj)
|
||||
err := process.ImageGenerate(c, &msgObj)
|
||||
if err != nil {
|
||||
logger.Warning(fmt.Errorf("process request: %v", err))
|
||||
return
|
||||
|
||||
@@ -2,8 +2,12 @@ package chatgpt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
|
||||
"github.com/pandodao/tokenizer-go"
|
||||
"image/png"
|
||||
"os"
|
||||
@@ -218,7 +222,7 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
|
||||
return resp.Choices[0].Text, nil
|
||||
}
|
||||
}
|
||||
func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
|
||||
func (c *ChatGPT) GenreateImage(ctx context.Context, prompt string) (string, error) {
|
||||
model := public.Config.Model
|
||||
if model == openai.GPT3Dot5Turbo0301 ||
|
||||
model == openai.GPT3Dot5Turbo ||
|
||||
@@ -247,6 +251,13 @@ func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
|
||||
}
|
||||
|
||||
imageName := time.Now().Format("20060102-150405") + ".png"
|
||||
clientId, _ := ctx.Value(public.DingTalkClientIdKeyName).(string)
|
||||
client := public.DingTalkClientManager.GetClientByOAuthClientID(clientId)
|
||||
mediaResult, uploadErr := &dingbot.MediaUploadResult{}, errors.New(fmt.Sprintf("unknown clientId: %s", clientId))
|
||||
if client != nil {
|
||||
mediaResult, uploadErr = client.UploadMedia(imgBytes, imageName, dingbot.MediaTypeImage, dingbot.MimeTypeImagePng)
|
||||
}
|
||||
|
||||
err = os.MkdirAll("data/images", 0755)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -260,9 +271,12 @@ func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
|
||||
if err := png.Encode(file, imgData); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if uploadErr == nil {
|
||||
return mediaResult.MediaID, nil
|
||||
} else {
|
||||
return public.Config.ServiceURL + "/images/" + imageName, nil
|
||||
}
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package chatgpt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/avast/retry-go"
|
||||
@@ -58,7 +59,7 @@ func ContextQa(question, userId string) (chat *ChatGPT, answer string, err error
|
||||
}
|
||||
|
||||
// ImageQa 生成图片
|
||||
func ImageQa(question, userId string) (answer string, err error) {
|
||||
func ImageQa(ctx context.Context, question, userId string) (answer string, err error) {
|
||||
chat := New(userId)
|
||||
defer chat.Close()
|
||||
// 定义一个重试策略
|
||||
@@ -70,7 +71,7 @@ func ImageQa(question, userId string) (answer string, err error) {
|
||||
// 使用重试策略进行重试
|
||||
err = retry.Do(
|
||||
func() error {
|
||||
answer, err = chat.GenreateImage(question)
|
||||
answer, err = chat.GenreateImage(ctx, question)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
213
pkg/dingbot/client.go
Normal file
213
pkg/dingbot/client.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package dingbot
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/eryajf/chatgpt-dingtalk/config"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
url2 "net/url"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OpenAPI doc: https://open.dingtalk.com/document/isvapp/upload-media-files
|
||||
const (
|
||||
MediaTypeImage string = "image"
|
||||
MediaTypeVoice string = "voice"
|
||||
MediaTypeVideo string = "video"
|
||||
MediaTypeFile string = "file"
|
||||
)
|
||||
const (
|
||||
MimeTypeImagePng string = "image/png"
|
||||
)
|
||||
|
||||
type MediaUploadResult struct {
|
||||
ErrorCode int64 `json:"errcode"`
|
||||
ErrorMessage string `json:"errmsg"`
|
||||
MediaID string `json:"media_id"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
type OAuthTokenResult struct {
|
||||
ErrorCode int `json:"errcode"`
|
||||
ErrorMessage string `json:"errmsg"`
|
||||
AccessToken string `json:"access_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
type DingTalkClientInterface interface {
|
||||
GetAccessToken() (string, error)
|
||||
UploadMedia(content []byte, filename, mediaType, mimeType string) (*MediaUploadResult, error)
|
||||
}
|
||||
|
||||
type DingTalkClientManagerInterface interface {
|
||||
GetClientByOAuthClientID(clientId string) DingTalkClientInterface
|
||||
}
|
||||
|
||||
type DingTalkClient struct {
|
||||
Credential config.Credential
|
||||
AccessToken string
|
||||
expireAt int64
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
type DingTalkClientManager struct {
|
||||
Credentials []config.Credential
|
||||
Clients map[string]*DingTalkClient
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
func NewDingTalkClient(credential config.Credential) *DingTalkClient {
|
||||
return &DingTalkClient{
|
||||
Credential: credential,
|
||||
}
|
||||
}
|
||||
|
||||
func NewDingTalkClientManager(conf *config.Configuration) *DingTalkClientManager {
|
||||
clients := make(map[string]*DingTalkClient)
|
||||
|
||||
if conf != nil && conf.Credentials != nil {
|
||||
for _, credential := range conf.Credentials {
|
||||
clients[credential.ClientID] = NewDingTalkClient(credential)
|
||||
}
|
||||
}
|
||||
return &DingTalkClientManager{
|
||||
Credentials: conf.Credentials,
|
||||
Clients: clients,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DingTalkClientManager) GetClientByOAuthClientID(clientId string) DingTalkClientInterface {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
if client, ok := m.Clients[clientId]; ok {
|
||||
return client
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) GetAccessToken() (string, error) {
|
||||
accessToken := ""
|
||||
{
|
||||
// 先查询缓存
|
||||
c.mutex.Lock()
|
||||
now := time.Now().Unix()
|
||||
if c.expireAt > 0 && c.AccessToken != "" && (now+60) < c.expireAt {
|
||||
// 预留一分钟有效期避免在Token过期的临界点调用接口出现401错误
|
||||
accessToken = c.AccessToken
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
if accessToken != "" {
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
tokenResult, err := c.getAccessTokenFromDingTalk()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
{
|
||||
// 更新缓存
|
||||
c.mutex.Lock()
|
||||
c.AccessToken = tokenResult.AccessToken
|
||||
c.expireAt = time.Now().Unix() + int64(tokenResult.ExpiresIn)
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
return tokenResult.AccessToken, nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) UploadMedia(content []byte, filename, mediaType, mimeType string) (*MediaUploadResult, error) {
|
||||
// OpenAPI doc: https://open.dingtalk.com/document/isvapp/upload-media-files
|
||||
accessToken, err := c.GetAccessToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(accessToken) == 0 {
|
||||
return nil, errors.New("empty access token")
|
||||
}
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
part, err := writer.CreateFormFile("media", filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = part.Write(content)
|
||||
writer.WriteField("type", mediaType)
|
||||
err = writer.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create a new HTTP request to upload the media file
|
||||
url := fmt.Sprintf("https://oapi.dingtalk.com/media/upload?access_token=%s", url2.QueryEscape(accessToken))
|
||||
req, err := http.NewRequest("POST", url, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
// Send the HTTP request and parse the response
|
||||
client := &http.Client{
|
||||
Timeout: time.Second * 60,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
// Parse the response body as JSON and extract the media ID
|
||||
media := &MediaUploadResult{}
|
||||
bodyBytes, err := io.ReadAll(res.Body)
|
||||
json.Unmarshal(bodyBytes, media)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if media.ErrorCode != 0 {
|
||||
return nil, errors.New(media.ErrorMessage)
|
||||
}
|
||||
return media, nil
|
||||
}
|
||||
|
||||
func (c *DingTalkClient) getAccessTokenFromDingTalk() (*OAuthTokenResult, error) {
|
||||
// OpenAPI doc: https://open.dingtalk.com/document/orgapp/obtain-orgapp-token
|
||||
apiUrl := "https://oapi.dingtalk.com/gettoken"
|
||||
queryParams := url2.Values{}
|
||||
queryParams.Add("appkey", c.Credential.ClientID)
|
||||
queryParams.Add("appsecret", c.Credential.ClientSecret)
|
||||
|
||||
// Create a new HTTP request to get the AccessToken
|
||||
req, err := http.NewRequest("GET", apiUrl+"?"+queryParams.Encode(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send the HTTP request and parse the response body as JSON
|
||||
client := http.Client{
|
||||
Timeout: time.Second * 60,
|
||||
}
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenResult := &OAuthTokenResult{}
|
||||
err = json.Unmarshal(body, tokenResult)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tokenResult.ErrorCode != 0 {
|
||||
return nil, errors.New(tokenResult.ErrorMessage)
|
||||
}
|
||||
return tokenResult, nil
|
||||
}
|
||||
53
pkg/dingbot/client_test.go
Normal file
53
pkg/dingbot/client_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package dingbot
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/eryajf/chatgpt-dingtalk/config"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/png"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUploadMedia_Pass_WithValidConfig(t *testing.T) {
|
||||
// 设置了钉钉 ClientID 和 ClientSecret 的环境变量才执行以下测试,用于快速验证钉钉图片上传能力
|
||||
clientId, clientSecret := os.Getenv("DINGTALK_CLIENT_ID_FOR_TEST"), os.Getenv("DINGTALK_CLIENT_SECRET_FOR_TEST")
|
||||
if len(clientId) <= 0 || len(clientSecret) <= 0 {
|
||||
return
|
||||
}
|
||||
credentials := []config.Credential{
|
||||
config.Credential{
|
||||
ClientID: clientId,
|
||||
ClientSecret: clientSecret,
|
||||
},
|
||||
}
|
||||
client := NewDingTalkClientManager(&config.Configuration{Credentials: credentials}).GetClientByOAuthClientID(clientId)
|
||||
var imageContent []byte
|
||||
{
|
||||
// 生成一张用于测试的图片
|
||||
img := image.NewRGBA(image.Rect(0, 0, 200, 100))
|
||||
blue := color.RGBA{0, 0, 255, 255}
|
||||
for x := 0; x < img.Bounds().Dx(); x++ {
|
||||
for y := 0; y < img.Bounds().Dy(); y++ {
|
||||
img.Set(x, y, blue)
|
||||
}
|
||||
}
|
||||
buf := new(bytes.Buffer)
|
||||
err := png.Encode(buf, img)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// get the byte array from the buffer
|
||||
imageContent = buf.Bytes()
|
||||
}
|
||||
result, err := client.UploadMedia(imageContent, "filename.png", "image", "image/png")
|
||||
if err != nil {
|
||||
t.Errorf("upload media failed, err=%s", err.Error())
|
||||
return
|
||||
}
|
||||
if result.MediaID == "" {
|
||||
t.Errorf("upload media failed, empty media id")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/eryajf/chatgpt-dingtalk/public"
|
||||
"strings"
|
||||
@@ -12,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
// ImageGenerate openai生成图片
|
||||
func ImageGenerate(rmsg *dingbot.ReceiveMsg) error {
|
||||
func ImageGenerate(ctx context.Context, rmsg *dingbot.ReceiveMsg) error {
|
||||
if public.Config.AzureOn {
|
||||
_, err := rmsg.ReplyToDingtalk(string(dingbot.
|
||||
MARKDOWN), "azure 模式下暂不支持图片创作功能")
|
||||
@@ -32,7 +33,7 @@ func ImageGenerate(rmsg *dingbot.ReceiveMsg) error {
|
||||
if err != nil {
|
||||
logger.Error("往MySQL新增数据失败,错误信息:", err)
|
||||
}
|
||||
reply, err := chatgpt.ImageQa(rmsg.Text.Content, rmsg.GetSenderIdentifier())
|
||||
reply, err := chatgpt.ImageQa(ctx, rmsg.Text.Content, rmsg.GetSenderIdentifier())
|
||||
if err != nil {
|
||||
logger.Info(fmt.Errorf("gpt request error: %v", err))
|
||||
_, err = rmsg.ReplyToDingtalk(string(dingbot.TEXT), fmt.Sprintf("请求openai失败了,错误信息:%v", err))
|
||||
|
||||
@@ -4,12 +4,16 @@ import (
|
||||
"github.com/eryajf/chatgpt-dingtalk/config"
|
||||
"github.com/eryajf/chatgpt-dingtalk/pkg/cache"
|
||||
"github.com/eryajf/chatgpt-dingtalk/pkg/db"
|
||||
"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
var UserService cache.UserServiceInterface
|
||||
var Config *config.Configuration
|
||||
var Prompt *[]config.Prompt
|
||||
var DingTalkClientManager dingbot.DingTalkClientManagerInterface
|
||||
|
||||
const DingTalkClientIdKeyName = "DingTalkClientId"
|
||||
|
||||
func InitSvc() {
|
||||
// 加载配置
|
||||
@@ -18,6 +22,8 @@ func InitSvc() {
|
||||
Prompt = config.LoadPrompt()
|
||||
// 初始化缓存
|
||||
UserService = cache.NewUserService()
|
||||
// 初始化钉钉开放平台的客户端,用于访问上传图片等能力
|
||||
DingTalkClientManager = dingbot.NewDingTalkClientManager(Config)
|
||||
// 初始化数据库
|
||||
db.InitDB()
|
||||
// 暂时不在初始化时获取余额
|
||||
|
||||
@@ -124,6 +124,23 @@ func GetReadTime(t time.Time) string {
|
||||
return t.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
func CheckRequestWithCredentials(ts, sg string) (clientId string, pass bool) {
|
||||
clientId, pass = "", false
|
||||
credentials := Config.Credentials
|
||||
if credentials == nil || len(credentials) == 0 {
|
||||
return "", true
|
||||
}
|
||||
for _, credential := range Config.Credentials {
|
||||
stringToSign := fmt.Sprintf("%s\n%s", ts, credential.ClientSecret)
|
||||
mac := hmac.New(sha256.New, []byte(credential.ClientSecret))
|
||||
_, _ = mac.Write([]byte(stringToSign))
|
||||
if base64.StdEncoding.EncodeToString(mac.Sum(nil)) == sg {
|
||||
return credential.ClientID, true
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func CheckRequest(ts, sg string) bool {
|
||||
appSecrets := Config.AppSecrets
|
||||
// 如果没有指定或者outgoing类型机器人下使用,则默认不做校验
|
||||
|
||||
76
public/tools_test.go
Normal file
76
public/tools_test.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package public
|
||||
|
||||
import (
|
||||
"github.com/eryajf/chatgpt-dingtalk/config"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCheckRequestWithCredentials_Pass_WithNilConfig(t *testing.T) {
|
||||
Config = &config.Configuration{
|
||||
Credentials: nil,
|
||||
}
|
||||
clientId, pass := CheckRequestWithCredentials("ts", "sg")
|
||||
if !pass {
|
||||
t.Errorf("pass should be true, but false")
|
||||
return
|
||||
}
|
||||
if len(clientId) > 0 {
|
||||
t.Errorf("client id should be empty")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckRequestWithCredentials_Pass_WithEmptyConfig(t *testing.T) {
|
||||
Config = &config.Configuration{
|
||||
Credentials: []config.Credential{},
|
||||
}
|
||||
clientId, pass := CheckRequestWithCredentials("ts", "sg")
|
||||
if !pass {
|
||||
t.Errorf("pass should be true, but false")
|
||||
return
|
||||
}
|
||||
if len(clientId) > 0 {
|
||||
t.Errorf("client id should be empty")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckRequestWithCredentials_Pass_WithValidConfig(t *testing.T) {
|
||||
Config = &config.Configuration{
|
||||
Credentials: []config.Credential{
|
||||
config.Credential{
|
||||
ClientID: "client-id-for-test",
|
||||
ClientSecret: "client-secret-for-test",
|
||||
},
|
||||
},
|
||||
}
|
||||
clientId, pass := CheckRequestWithCredentials("1684493546276", "nwBJQmaBLv9+5/sSS/66jcFc1/kGY5wo38L88LOGfRU=")
|
||||
if !pass {
|
||||
t.Errorf("pass should be true, but false")
|
||||
return
|
||||
}
|
||||
if clientId != "client-id-for-test" {
|
||||
t.Errorf("client id should be \"%s\", but \"%s\"", "client-id-for-test", clientId)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckRequestWithCredentials_Failed_WithInvalidConfig(t *testing.T) {
|
||||
Config = &config.Configuration{
|
||||
Credentials: []config.Credential{
|
||||
config.Credential{
|
||||
ClientID: "client-id-for-test",
|
||||
ClientSecret: "invalid-client-secret-for-test",
|
||||
},
|
||||
},
|
||||
}
|
||||
clientId, pass := CheckRequestWithCredentials("1684493546276", "nwBJQmaBLv9+5/sSS/66jcFc1/kGY5wo38L88LOGfRU=")
|
||||
if pass {
|
||||
t.Errorf("pass should be false, but true")
|
||||
return
|
||||
}
|
||||
if clientId != "" {
|
||||
t.Errorf("client id should be empty")
|
||||
return
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user