mirror of
https://github.com/eryajf/chatgpt-dingtalk.git
synced 2025-10-30 19:16:19 +08:00
feat: 支持上传图片到钉钉平台,在图片生成流程中使用钉钉的图片 CDN 能力 (#225)
This commit is contained in:
10
README.md
10
README.md
@@ -225,6 +225,7 @@ $ docker run -itd --name chatgpt -p 8090:8090 \
|
|||||||
-e SENSITIVE_WORDS="aa,bb" \
|
-e SENSITIVE_WORDS="aa,bb" \
|
||||||
-e AZURE_ON="false" -e AZURE_API_VERSION="" -e AZURE_RESOURCE_NAME="" \
|
-e AZURE_ON="false" -e AZURE_API_VERSION="" -e AZURE_RESOURCE_NAME="" \
|
||||||
-e AZURE_DEPLOYMENT_NAME="" -e AZURE_OPENAI_TOKEN="" \
|
-e AZURE_DEPLOYMENT_NAME="" -e AZURE_OPENAI_TOKEN="" \
|
||||||
|
-e DINGTALK_CREDENTIALS="your_client_id1:secret1,your_client_id2:secret2" \
|
||||||
-e HELP="欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/)
|
-e HELP="欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/)
|
||||||
,觉得不错你可以来波素质三连." \
|
,觉得不错你可以来波素质三连." \
|
||||||
--restart=always dockerproxy.com/eryajf/chatgpt-dingtalk:latest
|
--restart=always dockerproxy.com/eryajf/chatgpt-dingtalk:latest
|
||||||
@@ -541,6 +542,15 @@ azure_resource_name: "xxxx"
|
|||||||
azure_deployment_name: "xxxx"
|
azure_deployment_name: "xxxx"
|
||||||
azure_openai_token: "xxxx"
|
azure_openai_token: "xxxx"
|
||||||
|
|
||||||
|
# 钉钉应用鉴权凭据信息,支持多个应用。通过请求时候鉴权来识别是来自哪个机器人应用的消息
|
||||||
|
# 设置credentials 之后,即具备了访问钉钉平台绝大部分 OpenAPI 的能力;例如上传图片到钉钉平台,提升图片体验,结合 Stream 模式简化服务部署
|
||||||
|
# client_id 对应钉钉平台 AppKey/SuiteKey;client_secret 对应 AppSecret/SuiteSecret
|
||||||
|
# 建议采用 credentials 代替 app_secrets 配置项,以获得钉钉 OpenAPI 访问能力
|
||||||
|
credentials:
|
||||||
|
-
|
||||||
|
client_id: "put-your-client-id-here"
|
||||||
|
client_secret: "put-your-client-secret-here"
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## 常见问题
|
## 常见问题
|
||||||
|
|||||||
@@ -59,3 +59,10 @@ azure_resource_name: "xxxx"
|
|||||||
azure_deployment_name: "xxxx"
|
azure_deployment_name: "xxxx"
|
||||||
azure_openai_token: "xxxx"
|
azure_openai_token: "xxxx"
|
||||||
|
|
||||||
|
# 钉钉应用鉴权凭据信息,支持多个应用。通过请求时候鉴权来识别是来自哪个机器人应用的消息
|
||||||
|
# 设置credentials 之后,即具备了访问钉钉平台绝大部分 OpenAPI 的能力;例如上传图片到钉钉平台,提升图片体验,结合 Stream 模式简化服务部署
|
||||||
|
# client_id 对应钉钉平台 AppKey/SuiteKey;client_secret 对应 AppSecret/SuiteSecret
|
||||||
|
#credentials:
|
||||||
|
# -
|
||||||
|
# client_id: "put-your-client-id-here"
|
||||||
|
# client_secret: "put-your-client-secret-here"
|
||||||
|
|||||||
@@ -14,6 +14,11 @@ import (
|
|||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Credential struct {
|
||||||
|
ClientID string `yaml:"client_id"`
|
||||||
|
ClientSecret string `yaml:"client_secret"`
|
||||||
|
}
|
||||||
|
|
||||||
// Configuration 项目配置
|
// Configuration 项目配置
|
||||||
type Configuration struct {
|
type Configuration struct {
|
||||||
// 日志级别,info或者debug
|
// 日志级别,info或者debug
|
||||||
@@ -62,6 +67,8 @@ type Configuration struct {
|
|||||||
AzureResourceName string `yaml:"azure_resource_name"`
|
AzureResourceName string `yaml:"azure_resource_name"`
|
||||||
AzureDeploymentName string `yaml:"azure_deployment_name"`
|
AzureDeploymentName string `yaml:"azure_deployment_name"`
|
||||||
AzureOpenAIToken string `yaml:"azure_openai_token"`
|
AzureOpenAIToken string `yaml:"azure_openai_token"`
|
||||||
|
// 钉钉应用鉴权凭据
|
||||||
|
Credentials []Credential `yaml:"credentials"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var config *Configuration
|
var config *Configuration
|
||||||
@@ -190,6 +197,18 @@ func LoadConfig() *Configuration {
|
|||||||
if azureOpenaiToken != "" {
|
if azureOpenaiToken != "" {
|
||||||
config.AzureOpenAIToken = azureOpenaiToken
|
config.AzureOpenAIToken = azureOpenaiToken
|
||||||
}
|
}
|
||||||
|
credentials := os.Getenv("DINGTALK_CREDENTIALS")
|
||||||
|
if credentials != "" {
|
||||||
|
if config.Credentials == nil {
|
||||||
|
config.Credentials = []Credential{}
|
||||||
|
}
|
||||||
|
for _, idSecret := range strings.Split(credentials, ",") {
|
||||||
|
items := strings.SplitN(idSecret, ":", 2)
|
||||||
|
if len(items) == 2 {
|
||||||
|
config.Credentials = append(config.Credentials, Credential{ClientID: items[0], ClientSecret: items[1]})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ services:
|
|||||||
AZURE_RESOURCE_NAME: "" # Azure OpenAi API 资源名称,比如 "openai"
|
AZURE_RESOURCE_NAME: "" # Azure OpenAi API 资源名称,比如 "openai"
|
||||||
AZURE_DEPLOYMENT_NAME: "" # Azure OpenAi API 部署名称,比如 "openai"
|
AZURE_DEPLOYMENT_NAME: "" # Azure OpenAi API 部署名称,比如 "openai"
|
||||||
AZURE_OPENAI_TOKEN: "" # Azure token
|
AZURE_OPENAI_TOKEN: "" # Azure token
|
||||||
|
DINGTALK_CREDENTIALS: "" # 钉钉应用访问凭证,比如 "client_id1:secret1,client_id2:secret2"
|
||||||
HELP: "欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/),觉得不错你可以来波素质三连." # 帮助信息,放在配置文件,可供自定义
|
HELP: "欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/),觉得不错你可以来波素质三连." # 帮助信息,放在配置文件,可供自定义
|
||||||
volumes:
|
volumes:
|
||||||
- ./data:/app/data
|
- ./data:/app/data
|
||||||
|
|||||||
10
main.go
10
main.go
@@ -33,6 +33,14 @@ func Start() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
// 先校验回调是否合法
|
// 先校验回调是否合法
|
||||||
|
clientId, checkOk := public.CheckRequestWithCredentials(c.GetHeader("timestamp"), c.GetHeader("sign"))
|
||||||
|
if !checkOk {
|
||||||
|
logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 通过 context 传递 OAuth ClientID,用于后续流程中调用钉钉OpenAPI
|
||||||
|
c.Set(public.DingTalkClientIdKeyName, clientId)
|
||||||
|
// 为了兼容存量老用户,暂时保留 public.CheckRequest 方法,将来升级到 Stream 模式后,建议去除该方法,采用上面的 CheckRequestWithCredentials
|
||||||
if !public.CheckRequest(c.GetHeader("timestamp"), c.GetHeader("sign")) && msgObj.SenderStaffId != "" {
|
if !public.CheckRequest(c.GetHeader("timestamp"), c.GetHeader("sign")) && msgObj.SenderStaffId != "" {
|
||||||
logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!")
|
logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!")
|
||||||
return
|
return
|
||||||
@@ -114,7 +122,7 @@ func Start() {
|
|||||||
// 除去帮助之外的逻辑分流在这里处理
|
// 除去帮助之外的逻辑分流在这里处理
|
||||||
switch {
|
switch {
|
||||||
case strings.HasPrefix(msgObj.Text.Content, "#图片"):
|
case strings.HasPrefix(msgObj.Text.Content, "#图片"):
|
||||||
err := process.ImageGenerate(&msgObj)
|
err := process.ImageGenerate(c, &msgObj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Warning(fmt.Errorf("process request: %v", err))
|
logger.Warning(fmt.Errorf("process request: %v", err))
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -2,8 +2,12 @@ package chatgpt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
|
||||||
"github.com/pandodao/tokenizer-go"
|
"github.com/pandodao/tokenizer-go"
|
||||||
"image/png"
|
"image/png"
|
||||||
"os"
|
"os"
|
||||||
@@ -218,7 +222,7 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) {
|
|||||||
return resp.Choices[0].Text, nil
|
return resp.Choices[0].Text, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
|
func (c *ChatGPT) GenreateImage(ctx context.Context, prompt string) (string, error) {
|
||||||
model := public.Config.Model
|
model := public.Config.Model
|
||||||
if model == openai.GPT3Dot5Turbo0301 ||
|
if model == openai.GPT3Dot5Turbo0301 ||
|
||||||
model == openai.GPT3Dot5Turbo ||
|
model == openai.GPT3Dot5Turbo ||
|
||||||
@@ -247,6 +251,13 @@ func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
imageName := time.Now().Format("20060102-150405") + ".png"
|
imageName := time.Now().Format("20060102-150405") + ".png"
|
||||||
|
clientId, _ := ctx.Value(public.DingTalkClientIdKeyName).(string)
|
||||||
|
client := public.DingTalkClientManager.GetClientByOAuthClientID(clientId)
|
||||||
|
mediaResult, uploadErr := &dingbot.MediaUploadResult{}, errors.New(fmt.Sprintf("unknown clientId: %s", clientId))
|
||||||
|
if client != nil {
|
||||||
|
mediaResult, uploadErr = client.UploadMedia(imgBytes, imageName, dingbot.MediaTypeImage, dingbot.MimeTypeImagePng)
|
||||||
|
}
|
||||||
|
|
||||||
err = os.MkdirAll("data/images", 0755)
|
err = os.MkdirAll("data/images", 0755)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -260,8 +271,11 @@ func (c *ChatGPT) GenreateImage(prompt string) (string, error) {
|
|||||||
if err := png.Encode(file, imgData); err != nil {
|
if err := png.Encode(file, imgData); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
if uploadErr == nil {
|
||||||
return public.Config.ServiceURL + "/images/" + imageName, nil
|
return mediaResult.MediaID, nil
|
||||||
|
} else {
|
||||||
|
return public.Config.ServiceURL + "/images/" + imageName, nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package chatgpt
|
package chatgpt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/avast/retry-go"
|
"github.com/avast/retry-go"
|
||||||
@@ -58,7 +59,7 @@ func ContextQa(question, userId string) (chat *ChatGPT, answer string, err error
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ImageQa 生成图片
|
// ImageQa 生成图片
|
||||||
func ImageQa(question, userId string) (answer string, err error) {
|
func ImageQa(ctx context.Context, question, userId string) (answer string, err error) {
|
||||||
chat := New(userId)
|
chat := New(userId)
|
||||||
defer chat.Close()
|
defer chat.Close()
|
||||||
// 定义一个重试策略
|
// 定义一个重试策略
|
||||||
@@ -70,7 +71,7 @@ func ImageQa(question, userId string) (answer string, err error) {
|
|||||||
// 使用重试策略进行重试
|
// 使用重试策略进行重试
|
||||||
err = retry.Do(
|
err = retry.Do(
|
||||||
func() error {
|
func() error {
|
||||||
answer, err = chat.GenreateImage(question)
|
answer, err = chat.GenreateImage(ctx, question)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
213
pkg/dingbot/client.go
Normal file
213
pkg/dingbot/client.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package dingbot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/eryajf/chatgpt-dingtalk/config"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
url2 "net/url"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenAPI doc: https://open.dingtalk.com/document/isvapp/upload-media-files
|
||||||
|
const (
|
||||||
|
MediaTypeImage string = "image"
|
||||||
|
MediaTypeVoice string = "voice"
|
||||||
|
MediaTypeVideo string = "video"
|
||||||
|
MediaTypeFile string = "file"
|
||||||
|
)
|
||||||
|
const (
|
||||||
|
MimeTypeImagePng string = "image/png"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MediaUploadResult struct {
|
||||||
|
ErrorCode int64 `json:"errcode"`
|
||||||
|
ErrorMessage string `json:"errmsg"`
|
||||||
|
MediaID string `json:"media_id"`
|
||||||
|
CreatedAt int64 `json:"created_at"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OAuthTokenResult struct {
|
||||||
|
ErrorCode int `json:"errcode"`
|
||||||
|
ErrorMessage string `json:"errmsg"`
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type DingTalkClientInterface interface {
|
||||||
|
GetAccessToken() (string, error)
|
||||||
|
UploadMedia(content []byte, filename, mediaType, mimeType string) (*MediaUploadResult, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type DingTalkClientManagerInterface interface {
|
||||||
|
GetClientByOAuthClientID(clientId string) DingTalkClientInterface
|
||||||
|
}
|
||||||
|
|
||||||
|
type DingTalkClient struct {
|
||||||
|
Credential config.Credential
|
||||||
|
AccessToken string
|
||||||
|
expireAt int64
|
||||||
|
mutex sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type DingTalkClientManager struct {
|
||||||
|
Credentials []config.Credential
|
||||||
|
Clients map[string]*DingTalkClient
|
||||||
|
mutex sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDingTalkClient(credential config.Credential) *DingTalkClient {
|
||||||
|
return &DingTalkClient{
|
||||||
|
Credential: credential,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDingTalkClientManager(conf *config.Configuration) *DingTalkClientManager {
|
||||||
|
clients := make(map[string]*DingTalkClient)
|
||||||
|
|
||||||
|
if conf != nil && conf.Credentials != nil {
|
||||||
|
for _, credential := range conf.Credentials {
|
||||||
|
clients[credential.ClientID] = NewDingTalkClient(credential)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &DingTalkClientManager{
|
||||||
|
Credentials: conf.Credentials,
|
||||||
|
Clients: clients,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *DingTalkClientManager) GetClientByOAuthClientID(clientId string) DingTalkClientInterface {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
if client, ok := m.Clients[clientId]; ok {
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DingTalkClient) GetAccessToken() (string, error) {
|
||||||
|
accessToken := ""
|
||||||
|
{
|
||||||
|
// 先查询缓存
|
||||||
|
c.mutex.Lock()
|
||||||
|
now := time.Now().Unix()
|
||||||
|
if c.expireAt > 0 && c.AccessToken != "" && (now+60) < c.expireAt {
|
||||||
|
// 预留一分钟有效期避免在Token过期的临界点调用接口出现401错误
|
||||||
|
accessToken = c.AccessToken
|
||||||
|
}
|
||||||
|
c.mutex.Unlock()
|
||||||
|
}
|
||||||
|
if accessToken != "" {
|
||||||
|
return accessToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResult, err := c.getAccessTokenFromDingTalk()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// 更新缓存
|
||||||
|
c.mutex.Lock()
|
||||||
|
c.AccessToken = tokenResult.AccessToken
|
||||||
|
c.expireAt = time.Now().Unix() + int64(tokenResult.ExpiresIn)
|
||||||
|
c.mutex.Unlock()
|
||||||
|
}
|
||||||
|
return tokenResult.AccessToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DingTalkClient) UploadMedia(content []byte, filename, mediaType, mimeType string) (*MediaUploadResult, error) {
|
||||||
|
// OpenAPI doc: https://open.dingtalk.com/document/isvapp/upload-media-files
|
||||||
|
accessToken, err := c.GetAccessToken()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(accessToken) == 0 {
|
||||||
|
return nil, errors.New("empty access token")
|
||||||
|
}
|
||||||
|
body := &bytes.Buffer{}
|
||||||
|
writer := multipart.NewWriter(body)
|
||||||
|
part, err := writer.CreateFormFile("media", filename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, err = part.Write(content)
|
||||||
|
writer.WriteField("type", mediaType)
|
||||||
|
err = writer.Close()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new HTTP request to upload the media file
|
||||||
|
url := fmt.Sprintf("https://oapi.dingtalk.com/media/upload?access_token=%s", url2.QueryEscape(accessToken))
|
||||||
|
req, err := http.NewRequest("POST", url, body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||||
|
|
||||||
|
// Send the HTTP request and parse the response
|
||||||
|
client := &http.Client{
|
||||||
|
Timeout: time.Second * 60,
|
||||||
|
}
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
// Parse the response body as JSON and extract the media ID
|
||||||
|
media := &MediaUploadResult{}
|
||||||
|
bodyBytes, err := io.ReadAll(res.Body)
|
||||||
|
json.Unmarshal(bodyBytes, media)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if media.ErrorCode != 0 {
|
||||||
|
return nil, errors.New(media.ErrorMessage)
|
||||||
|
}
|
||||||
|
return media, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *DingTalkClient) getAccessTokenFromDingTalk() (*OAuthTokenResult, error) {
|
||||||
|
// OpenAPI doc: https://open.dingtalk.com/document/orgapp/obtain-orgapp-token
|
||||||
|
apiUrl := "https://oapi.dingtalk.com/gettoken"
|
||||||
|
queryParams := url2.Values{}
|
||||||
|
queryParams.Add("appkey", c.Credential.ClientID)
|
||||||
|
queryParams.Add("appsecret", c.Credential.ClientSecret)
|
||||||
|
|
||||||
|
// Create a new HTTP request to get the AccessToken
|
||||||
|
req, err := http.NewRequest("GET", apiUrl+"?"+queryParams.Encode(), nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the HTTP request and parse the response body as JSON
|
||||||
|
client := http.Client{
|
||||||
|
Timeout: time.Second * 60,
|
||||||
|
}
|
||||||
|
res, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
body, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tokenResult := &OAuthTokenResult{}
|
||||||
|
err = json.Unmarshal(body, tokenResult)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if tokenResult.ErrorCode != 0 {
|
||||||
|
return nil, errors.New(tokenResult.ErrorMessage)
|
||||||
|
}
|
||||||
|
return tokenResult, nil
|
||||||
|
}
|
||||||
53
pkg/dingbot/client_test.go
Normal file
53
pkg/dingbot/client_test.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package dingbot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"github.com/eryajf/chatgpt-dingtalk/config"
|
||||||
|
"image"
|
||||||
|
"image/color"
|
||||||
|
"image/png"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUploadMedia_Pass_WithValidConfig(t *testing.T) {
|
||||||
|
// 设置了钉钉 ClientID 和 ClientSecret 的环境变量才执行以下测试,用于快速验证钉钉图片上传能力
|
||||||
|
clientId, clientSecret := os.Getenv("DINGTALK_CLIENT_ID_FOR_TEST"), os.Getenv("DINGTALK_CLIENT_SECRET_FOR_TEST")
|
||||||
|
if len(clientId) <= 0 || len(clientSecret) <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
credentials := []config.Credential{
|
||||||
|
config.Credential{
|
||||||
|
ClientID: clientId,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client := NewDingTalkClientManager(&config.Configuration{Credentials: credentials}).GetClientByOAuthClientID(clientId)
|
||||||
|
var imageContent []byte
|
||||||
|
{
|
||||||
|
// 生成一张用于测试的图片
|
||||||
|
img := image.NewRGBA(image.Rect(0, 0, 200, 100))
|
||||||
|
blue := color.RGBA{0, 0, 255, 255}
|
||||||
|
for x := 0; x < img.Bounds().Dx(); x++ {
|
||||||
|
for y := 0; y < img.Bounds().Dy(); y++ {
|
||||||
|
img.Set(x, y, blue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
err := png.Encode(buf, img)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// get the byte array from the buffer
|
||||||
|
imageContent = buf.Bytes()
|
||||||
|
}
|
||||||
|
result, err := client.UploadMedia(imageContent, "filename.png", "image", "image/png")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("upload media failed, err=%s", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if result.MediaID == "" {
|
||||||
|
t.Errorf("upload media failed, empty media id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package process
|
package process
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/eryajf/chatgpt-dingtalk/public"
|
"github.com/eryajf/chatgpt-dingtalk/public"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -12,7 +13,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ImageGenerate openai生成图片
|
// ImageGenerate openai生成图片
|
||||||
func ImageGenerate(rmsg *dingbot.ReceiveMsg) error {
|
func ImageGenerate(ctx context.Context, rmsg *dingbot.ReceiveMsg) error {
|
||||||
if public.Config.AzureOn {
|
if public.Config.AzureOn {
|
||||||
_, err := rmsg.ReplyToDingtalk(string(dingbot.
|
_, err := rmsg.ReplyToDingtalk(string(dingbot.
|
||||||
MARKDOWN), "azure 模式下暂不支持图片创作功能")
|
MARKDOWN), "azure 模式下暂不支持图片创作功能")
|
||||||
@@ -32,7 +33,7 @@ func ImageGenerate(rmsg *dingbot.ReceiveMsg) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("往MySQL新增数据失败,错误信息:", err)
|
logger.Error("往MySQL新增数据失败,错误信息:", err)
|
||||||
}
|
}
|
||||||
reply, err := chatgpt.ImageQa(rmsg.Text.Content, rmsg.GetSenderIdentifier())
|
reply, err := chatgpt.ImageQa(ctx, rmsg.Text.Content, rmsg.GetSenderIdentifier())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Info(fmt.Errorf("gpt request error: %v", err))
|
logger.Info(fmt.Errorf("gpt request error: %v", err))
|
||||||
_, err = rmsg.ReplyToDingtalk(string(dingbot.TEXT), fmt.Sprintf("请求openai失败了,错误信息:%v", err))
|
_, err = rmsg.ReplyToDingtalk(string(dingbot.TEXT), fmt.Sprintf("请求openai失败了,错误信息:%v", err))
|
||||||
|
|||||||
@@ -4,12 +4,16 @@ import (
|
|||||||
"github.com/eryajf/chatgpt-dingtalk/config"
|
"github.com/eryajf/chatgpt-dingtalk/config"
|
||||||
"github.com/eryajf/chatgpt-dingtalk/pkg/cache"
|
"github.com/eryajf/chatgpt-dingtalk/pkg/cache"
|
||||||
"github.com/eryajf/chatgpt-dingtalk/pkg/db"
|
"github.com/eryajf/chatgpt-dingtalk/pkg/db"
|
||||||
|
"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot"
|
||||||
"github.com/sashabaranov/go-openai"
|
"github.com/sashabaranov/go-openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
var UserService cache.UserServiceInterface
|
var UserService cache.UserServiceInterface
|
||||||
var Config *config.Configuration
|
var Config *config.Configuration
|
||||||
var Prompt *[]config.Prompt
|
var Prompt *[]config.Prompt
|
||||||
|
var DingTalkClientManager dingbot.DingTalkClientManagerInterface
|
||||||
|
|
||||||
|
const DingTalkClientIdKeyName = "DingTalkClientId"
|
||||||
|
|
||||||
func InitSvc() {
|
func InitSvc() {
|
||||||
// 加载配置
|
// 加载配置
|
||||||
@@ -18,6 +22,8 @@ func InitSvc() {
|
|||||||
Prompt = config.LoadPrompt()
|
Prompt = config.LoadPrompt()
|
||||||
// 初始化缓存
|
// 初始化缓存
|
||||||
UserService = cache.NewUserService()
|
UserService = cache.NewUserService()
|
||||||
|
// 初始化钉钉开放平台的客户端,用于访问上传图片等能力
|
||||||
|
DingTalkClientManager = dingbot.NewDingTalkClientManager(Config)
|
||||||
// 初始化数据库
|
// 初始化数据库
|
||||||
db.InitDB()
|
db.InitDB()
|
||||||
// 暂时不在初始化时获取余额
|
// 暂时不在初始化时获取余额
|
||||||
|
|||||||
@@ -124,6 +124,23 @@ func GetReadTime(t time.Time) string {
|
|||||||
return t.Format("2006-01-02 15:04:05")
|
return t.Format("2006-01-02 15:04:05")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CheckRequestWithCredentials(ts, sg string) (clientId string, pass bool) {
|
||||||
|
clientId, pass = "", false
|
||||||
|
credentials := Config.Credentials
|
||||||
|
if credentials == nil || len(credentials) == 0 {
|
||||||
|
return "", true
|
||||||
|
}
|
||||||
|
for _, credential := range Config.Credentials {
|
||||||
|
stringToSign := fmt.Sprintf("%s\n%s", ts, credential.ClientSecret)
|
||||||
|
mac := hmac.New(sha256.New, []byte(credential.ClientSecret))
|
||||||
|
_, _ = mac.Write([]byte(stringToSign))
|
||||||
|
if base64.StdEncoding.EncodeToString(mac.Sum(nil)) == sg {
|
||||||
|
return credential.ClientID, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func CheckRequest(ts, sg string) bool {
|
func CheckRequest(ts, sg string) bool {
|
||||||
appSecrets := Config.AppSecrets
|
appSecrets := Config.AppSecrets
|
||||||
// 如果没有指定或者outgoing类型机器人下使用,则默认不做校验
|
// 如果没有指定或者outgoing类型机器人下使用,则默认不做校验
|
||||||
|
|||||||
76
public/tools_test.go
Normal file
76
public/tools_test.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package public
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/eryajf/chatgpt-dingtalk/config"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCheckRequestWithCredentials_Pass_WithNilConfig(t *testing.T) {
|
||||||
|
Config = &config.Configuration{
|
||||||
|
Credentials: nil,
|
||||||
|
}
|
||||||
|
clientId, pass := CheckRequestWithCredentials("ts", "sg")
|
||||||
|
if !pass {
|
||||||
|
t.Errorf("pass should be true, but false")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(clientId) > 0 {
|
||||||
|
t.Errorf("client id should be empty")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckRequestWithCredentials_Pass_WithEmptyConfig(t *testing.T) {
|
||||||
|
Config = &config.Configuration{
|
||||||
|
Credentials: []config.Credential{},
|
||||||
|
}
|
||||||
|
clientId, pass := CheckRequestWithCredentials("ts", "sg")
|
||||||
|
if !pass {
|
||||||
|
t.Errorf("pass should be true, but false")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(clientId) > 0 {
|
||||||
|
t.Errorf("client id should be empty")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckRequestWithCredentials_Pass_WithValidConfig(t *testing.T) {
|
||||||
|
Config = &config.Configuration{
|
||||||
|
Credentials: []config.Credential{
|
||||||
|
config.Credential{
|
||||||
|
ClientID: "client-id-for-test",
|
||||||
|
ClientSecret: "client-secret-for-test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
clientId, pass := CheckRequestWithCredentials("1684493546276", "nwBJQmaBLv9+5/sSS/66jcFc1/kGY5wo38L88LOGfRU=")
|
||||||
|
if !pass {
|
||||||
|
t.Errorf("pass should be true, but false")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if clientId != "client-id-for-test" {
|
||||||
|
t.Errorf("client id should be \"%s\", but \"%s\"", "client-id-for-test", clientId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckRequestWithCredentials_Failed_WithInvalidConfig(t *testing.T) {
|
||||||
|
Config = &config.Configuration{
|
||||||
|
Credentials: []config.Credential{
|
||||||
|
config.Credential{
|
||||||
|
ClientID: "client-id-for-test",
|
||||||
|
ClientSecret: "invalid-client-secret-for-test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
clientId, pass := CheckRequestWithCredentials("1684493546276", "nwBJQmaBLv9+5/sSS/66jcFc1/kGY5wo38L88LOGfRU=")
|
||||||
|
if pass {
|
||||||
|
t.Errorf("pass should be false, but true")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if clientId != "" {
|
||||||
|
t.Errorf("client id should be empty")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user