mirror of
				https://github.com/eryajf/chatgpt-dingtalk.git
				synced 2025-10-31 11:36:17 +08:00 
			
		
		
		
	feat: 支持上传图片到钉钉平台,在图片生成流程中使用钉钉的图片 CDN 能力 (#225)
This commit is contained in:
		
							
								
								
									
										10
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								README.md
									
									
									
									
									
								
							| @@ -225,6 +225,7 @@ $ docker run -itd --name chatgpt -p 8090:8090 \ | |||||||
|   -e SENSITIVE_WORDS="aa,bb" \ |   -e SENSITIVE_WORDS="aa,bb" \ | ||||||
|   -e AZURE_ON="false" -e AZURE_API_VERSION="" -e AZURE_RESOURCE_NAME="" \ |   -e AZURE_ON="false" -e AZURE_API_VERSION="" -e AZURE_RESOURCE_NAME="" \ | ||||||
|   -e AZURE_DEPLOYMENT_NAME="" -e AZURE_OPENAI_TOKEN="" \ |   -e AZURE_DEPLOYMENT_NAME="" -e AZURE_OPENAI_TOKEN="" \ | ||||||
|  |   -e DINGTALK_CREDENTIALS="your_client_id1:secret1,your_client_id2:secret2" \ | ||||||
|   -e HELP="欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/) |   -e HELP="欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/) | ||||||
|   ,觉得不错你可以来波素质三连."  \ |   ,觉得不错你可以来波素质三连."  \ | ||||||
|   --restart=always  dockerproxy.com/eryajf/chatgpt-dingtalk:latest |   --restart=always  dockerproxy.com/eryajf/chatgpt-dingtalk:latest | ||||||
| @@ -541,6 +542,15 @@ azure_resource_name: "xxxx" | |||||||
| azure_deployment_name: "xxxx" | azure_deployment_name: "xxxx" | ||||||
| azure_openai_token: "xxxx" | azure_openai_token: "xxxx" | ||||||
|  |  | ||||||
|  | # 钉钉应用鉴权凭据信息,支持多个应用。通过请求时候鉴权来识别是来自哪个机器人应用的消息 | ||||||
|  | # 设置credentials 之后,即具备了访问钉钉平台绝大部分 OpenAPI 的能力;例如上传图片到钉钉平台,提升图片体验,结合 Stream 模式简化服务部署 | ||||||
|  | # client_id 对应钉钉平台 AppKey/SuiteKey;client_secret 对应 AppSecret/SuiteSecret | ||||||
|  | # 建议采用 credentials 代替 app_secrets 配置项,以获得钉钉 OpenAPI 访问能力 | ||||||
|  | credentials: | ||||||
|  |   - | ||||||
|  |     client_id: "put-your-client-id-here" | ||||||
|  |     client_secret: "put-your-client-secret-here" | ||||||
|  |  | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| ## 常见问题 | ## 常见问题 | ||||||
|   | |||||||
| @@ -59,3 +59,10 @@ azure_resource_name: "xxxx" | |||||||
| azure_deployment_name: "xxxx" | azure_deployment_name: "xxxx" | ||||||
| azure_openai_token: "xxxx" | azure_openai_token: "xxxx" | ||||||
|  |  | ||||||
|  | # 钉钉应用鉴权凭据信息,支持多个应用。通过请求时候鉴权来识别是来自哪个机器人应用的消息 | ||||||
|  | # 设置credentials 之后,即具备了访问钉钉平台绝大部分 OpenAPI 的能力;例如上传图片到钉钉平台,提升图片体验,结合 Stream 模式简化服务部署 | ||||||
|  | # client_id 对应钉钉平台 AppKey/SuiteKey;client_secret 对应 AppSecret/SuiteSecret | ||||||
|  | #credentials: | ||||||
|  | #  - | ||||||
|  | #    client_id: "put-your-client-id-here" | ||||||
|  | #    client_secret: "put-your-client-secret-here" | ||||||
|   | |||||||
| @@ -14,6 +14,11 @@ import ( | |||||||
| 	"gopkg.in/yaml.v2" | 	"gopkg.in/yaml.v2" | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | type Credential struct { | ||||||
|  | 	ClientID     string `yaml:"client_id"` | ||||||
|  | 	ClientSecret string `yaml:"client_secret"` | ||||||
|  | } | ||||||
|  |  | ||||||
| // Configuration 项目配置 | // Configuration 项目配置 | ||||||
| type Configuration struct { | type Configuration struct { | ||||||
| 	// 日志级别,info或者debug | 	// 日志级别,info或者debug | ||||||
| @@ -62,6 +67,8 @@ type Configuration struct { | |||||||
| 	AzureResourceName   string `yaml:"azure_resource_name"` | 	AzureResourceName   string `yaml:"azure_resource_name"` | ||||||
| 	AzureDeploymentName string `yaml:"azure_deployment_name"` | 	AzureDeploymentName string `yaml:"azure_deployment_name"` | ||||||
| 	AzureOpenAIToken    string `yaml:"azure_openai_token"` | 	AzureOpenAIToken    string `yaml:"azure_openai_token"` | ||||||
|  | 	// 钉钉应用鉴权凭据 | ||||||
|  | 	Credentials []Credential `yaml:"credentials"` | ||||||
| } | } | ||||||
|  |  | ||||||
| var config *Configuration | var config *Configuration | ||||||
| @@ -190,6 +197,18 @@ func LoadConfig() *Configuration { | |||||||
| 		if azureOpenaiToken != "" { | 		if azureOpenaiToken != "" { | ||||||
| 			config.AzureOpenAIToken = azureOpenaiToken | 			config.AzureOpenAIToken = azureOpenaiToken | ||||||
| 		} | 		} | ||||||
|  | 		credentials := os.Getenv("DINGTALK_CREDENTIALS") | ||||||
|  | 		if credentials != "" { | ||||||
|  | 			if config.Credentials == nil { | ||||||
|  | 				config.Credentials = []Credential{} | ||||||
|  | 			} | ||||||
|  | 			for _, idSecret := range strings.Split(credentials, ",") { | ||||||
|  | 				items := strings.SplitN(idSecret, ":", 2) | ||||||
|  | 				if len(items) == 2 { | ||||||
|  | 					config.Credentials = append(config.Credentials, Credential{ClientID: items[0], ClientSecret: items[1]}) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -37,6 +37,7 @@ services: | |||||||
|       AZURE_RESOURCE_NAME: "" # Azure OpenAi API 资源名称,比如 "openai" |       AZURE_RESOURCE_NAME: "" # Azure OpenAi API 资源名称,比如 "openai" | ||||||
|       AZURE_DEPLOYMENT_NAME: "" # Azure OpenAi API 部署名称,比如 "openai" |       AZURE_DEPLOYMENT_NAME: "" # Azure OpenAi API 部署名称,比如 "openai" | ||||||
|       AZURE_OPENAI_TOKEN: "" # Azure token |       AZURE_OPENAI_TOKEN: "" # Azure token | ||||||
|  |       DINGTALK_CREDENTIALS: "" # 钉钉应用访问凭证,比如 "client_id1:secret1,client_id2:secret2" | ||||||
|       HELP: "欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/),觉得不错你可以来波素质三连." # 帮助信息,放在配置文件,可供自定义 |       HELP: "欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/),觉得不错你可以来波素质三连." # 帮助信息,放在配置文件,可供自定义 | ||||||
|     volumes: |     volumes: | ||||||
|       - ./data:/app/data |       - ./data:/app/data | ||||||
|   | |||||||
							
								
								
									
										10
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								main.go
									
									
									
									
									
								
							| @@ -33,6 +33,14 @@ func Start() { | |||||||
| 			return | 			return | ||||||
| 		} | 		} | ||||||
| 		// 先校验回调是否合法 | 		// 先校验回调是否合法 | ||||||
|  | 		clientId, checkOk := public.CheckRequestWithCredentials(c.GetHeader("timestamp"), c.GetHeader("sign")) | ||||||
|  | 		if !checkOk { | ||||||
|  | 			logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!") | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		// 通过 context 传递 OAuth ClientID,用于后续流程中调用钉钉OpenAPI | ||||||
|  | 		c.Set(public.DingTalkClientIdKeyName, clientId) | ||||||
|  | 		// 为了兼容存量老用户,暂时保留 public.CheckRequest 方法,将来升级到 Stream 模式后,建议去除该方法,采用上面的 CheckRequestWithCredentials | ||||||
| 		if !public.CheckRequest(c.GetHeader("timestamp"), c.GetHeader("sign")) && msgObj.SenderStaffId != "" { | 		if !public.CheckRequest(c.GetHeader("timestamp"), c.GetHeader("sign")) && msgObj.SenderStaffId != "" { | ||||||
| 			logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!") | 			logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!") | ||||||
| 			return | 			return | ||||||
| @@ -114,7 +122,7 @@ func Start() { | |||||||
| 			// 除去帮助之外的逻辑分流在这里处理 | 			// 除去帮助之外的逻辑分流在这里处理 | ||||||
| 			switch { | 			switch { | ||||||
| 			case strings.HasPrefix(msgObj.Text.Content, "#图片"): | 			case strings.HasPrefix(msgObj.Text.Content, "#图片"): | ||||||
| 				err := process.ImageGenerate(&msgObj) | 				err := process.ImageGenerate(c, &msgObj) | ||||||
| 				if err != nil { | 				if err != nil { | ||||||
| 					logger.Warning(fmt.Errorf("process request: %v", err)) | 					logger.Warning(fmt.Errorf("process request: %v", err)) | ||||||
| 					return | 					return | ||||||
|   | |||||||
| @@ -2,8 +2,12 @@ package chatgpt | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"context" | ||||||
| 	"encoding/base64" | 	"encoding/base64" | ||||||
| 	"encoding/gob" | 	"encoding/gob" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot" | ||||||
| 	"github.com/pandodao/tokenizer-go" | 	"github.com/pandodao/tokenizer-go" | ||||||
| 	"image/png" | 	"image/png" | ||||||
| 	"os" | 	"os" | ||||||
| @@ -218,7 +222,7 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) { | |||||||
| 		return resp.Choices[0].Text, nil | 		return resp.Choices[0].Text, nil | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| func (c *ChatGPT) GenreateImage(prompt string) (string, error) { | func (c *ChatGPT) GenreateImage(ctx context.Context, prompt string) (string, error) { | ||||||
| 	model := public.Config.Model | 	model := public.Config.Model | ||||||
| 	if model == openai.GPT3Dot5Turbo0301 || | 	if model == openai.GPT3Dot5Turbo0301 || | ||||||
| 		model == openai.GPT3Dot5Turbo || | 		model == openai.GPT3Dot5Turbo || | ||||||
| @@ -247,6 +251,13 @@ func (c *ChatGPT) GenreateImage(prompt string) (string, error) { | |||||||
| 		} | 		} | ||||||
|  |  | ||||||
| 		imageName := time.Now().Format("20060102-150405") + ".png" | 		imageName := time.Now().Format("20060102-150405") + ".png" | ||||||
|  | 		clientId, _ := ctx.Value(public.DingTalkClientIdKeyName).(string) | ||||||
|  | 		client := public.DingTalkClientManager.GetClientByOAuthClientID(clientId) | ||||||
|  | 		mediaResult, uploadErr := &dingbot.MediaUploadResult{}, errors.New(fmt.Sprintf("unknown clientId: %s", clientId)) | ||||||
|  | 		if client != nil { | ||||||
|  | 			mediaResult, uploadErr = client.UploadMedia(imgBytes, imageName, dingbot.MediaTypeImage, dingbot.MimeTypeImagePng) | ||||||
|  | 		} | ||||||
|  |  | ||||||
| 		err = os.MkdirAll("data/images", 0755) | 		err = os.MkdirAll("data/images", 0755) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return "", err | 			return "", err | ||||||
| @@ -260,9 +271,12 @@ func (c *ChatGPT) GenreateImage(prompt string) (string, error) { | |||||||
| 		if err := png.Encode(file, imgData); err != nil { | 		if err := png.Encode(file, imgData); err != nil { | ||||||
| 			return "", err | 			return "", err | ||||||
| 		} | 		} | ||||||
|  | 		if uploadErr == nil { | ||||||
|  | 			return mediaResult.MediaID, nil | ||||||
|  | 		} else { | ||||||
| 			return public.Config.ServiceURL + "/images/" + imageName, nil | 			return public.Config.ServiceURL + "/images/" + imageName, nil | ||||||
| 		} | 		} | ||||||
|  | 	} | ||||||
| 	return "", nil | 	return "", nil | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,6 +1,7 @@ | |||||||
| package chatgpt | package chatgpt | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"time" | 	"time" | ||||||
|  |  | ||||||
| 	"github.com/avast/retry-go" | 	"github.com/avast/retry-go" | ||||||
| @@ -58,7 +59,7 @@ func ContextQa(question, userId string) (chat *ChatGPT, answer string, err error | |||||||
| } | } | ||||||
|  |  | ||||||
| // ImageQa 生成图片 | // ImageQa 生成图片 | ||||||
| func ImageQa(question, userId string) (answer string, err error) { | func ImageQa(ctx context.Context, question, userId string) (answer string, err error) { | ||||||
| 	chat := New(userId) | 	chat := New(userId) | ||||||
| 	defer chat.Close() | 	defer chat.Close() | ||||||
| 	// 定义一个重试策略 | 	// 定义一个重试策略 | ||||||
| @@ -70,7 +71,7 @@ func ImageQa(question, userId string) (answer string, err error) { | |||||||
| 	// 使用重试策略进行重试 | 	// 使用重试策略进行重试 | ||||||
| 	err = retry.Do( | 	err = retry.Do( | ||||||
| 		func() error { | 		func() error { | ||||||
| 			answer, err = chat.GenreateImage(question) | 			answer, err = chat.GenreateImage(ctx, question) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				return err | 				return err | ||||||
| 			} | 			} | ||||||
|   | |||||||
							
								
								
									
										213
									
								
								pkg/dingbot/client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										213
									
								
								pkg/dingbot/client.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,213 @@ | |||||||
|  | package dingbot | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"github.com/eryajf/chatgpt-dingtalk/config" | ||||||
|  | 	"io" | ||||||
|  | 	"mime/multipart" | ||||||
|  | 	"net/http" | ||||||
|  | 	url2 "net/url" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // OpenAPI doc: https://open.dingtalk.com/document/isvapp/upload-media-files | ||||||
|  | const ( | ||||||
|  | 	MediaTypeImage string = "image" | ||||||
|  | 	MediaTypeVoice string = "voice" | ||||||
|  | 	MediaTypeVideo string = "video" | ||||||
|  | 	MediaTypeFile  string = "file" | ||||||
|  | ) | ||||||
|  | const ( | ||||||
|  | 	MimeTypeImagePng string = "image/png" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type MediaUploadResult struct { | ||||||
|  | 	ErrorCode    int64  `json:"errcode"` | ||||||
|  | 	ErrorMessage string `json:"errmsg"` | ||||||
|  | 	MediaID      string `json:"media_id"` | ||||||
|  | 	CreatedAt    int64  `json:"created_at"` | ||||||
|  | 	Type         string `json:"type"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type OAuthTokenResult struct { | ||||||
|  | 	ErrorCode    int    `json:"errcode"` | ||||||
|  | 	ErrorMessage string `json:"errmsg"` | ||||||
|  | 	AccessToken  string `json:"access_token"` | ||||||
|  | 	ExpiresIn    int    `json:"expires_in"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type DingTalkClientInterface interface { | ||||||
|  | 	GetAccessToken() (string, error) | ||||||
|  | 	UploadMedia(content []byte, filename, mediaType, mimeType string) (*MediaUploadResult, error) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type DingTalkClientManagerInterface interface { | ||||||
|  | 	GetClientByOAuthClientID(clientId string) DingTalkClientInterface | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type DingTalkClient struct { | ||||||
|  | 	Credential  config.Credential | ||||||
|  | 	AccessToken string | ||||||
|  | 	expireAt    int64 | ||||||
|  | 	mutex       sync.Mutex | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type DingTalkClientManager struct { | ||||||
|  | 	Credentials []config.Credential | ||||||
|  | 	Clients     map[string]*DingTalkClient | ||||||
|  | 	mutex       sync.Mutex | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewDingTalkClient(credential config.Credential) *DingTalkClient { | ||||||
|  | 	return &DingTalkClient{ | ||||||
|  | 		Credential: credential, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewDingTalkClientManager(conf *config.Configuration) *DingTalkClientManager { | ||||||
|  | 	clients := make(map[string]*DingTalkClient) | ||||||
|  |  | ||||||
|  | 	if conf != nil && conf.Credentials != nil { | ||||||
|  | 		for _, credential := range conf.Credentials { | ||||||
|  | 			clients[credential.ClientID] = NewDingTalkClient(credential) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return &DingTalkClientManager{ | ||||||
|  | 		Credentials: conf.Credentials, | ||||||
|  | 		Clients:     clients, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m *DingTalkClientManager) GetClientByOAuthClientID(clientId string) DingTalkClientInterface { | ||||||
|  | 	m.mutex.Lock() | ||||||
|  | 	defer m.mutex.Unlock() | ||||||
|  | 	if client, ok := m.Clients[clientId]; ok { | ||||||
|  | 		return client | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *DingTalkClient) GetAccessToken() (string, error) { | ||||||
|  | 	accessToken := "" | ||||||
|  | 	{ | ||||||
|  | 		// 先查询缓存 | ||||||
|  | 		c.mutex.Lock() | ||||||
|  | 		now := time.Now().Unix() | ||||||
|  | 		if c.expireAt > 0 && c.AccessToken != "" && (now+60) < c.expireAt { | ||||||
|  | 			// 预留一分钟有效期避免在Token过期的临界点调用接口出现401错误 | ||||||
|  | 			accessToken = c.AccessToken | ||||||
|  | 		} | ||||||
|  | 		c.mutex.Unlock() | ||||||
|  | 	} | ||||||
|  | 	if accessToken != "" { | ||||||
|  | 		return accessToken, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	tokenResult, err := c.getAccessTokenFromDingTalk() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	{ | ||||||
|  | 		// 更新缓存 | ||||||
|  | 		c.mutex.Lock() | ||||||
|  | 		c.AccessToken = tokenResult.AccessToken | ||||||
|  | 		c.expireAt = time.Now().Unix() + int64(tokenResult.ExpiresIn) | ||||||
|  | 		c.mutex.Unlock() | ||||||
|  | 	} | ||||||
|  | 	return tokenResult.AccessToken, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *DingTalkClient) UploadMedia(content []byte, filename, mediaType, mimeType string) (*MediaUploadResult, error) { | ||||||
|  | 	// OpenAPI doc: https://open.dingtalk.com/document/isvapp/upload-media-files | ||||||
|  | 	accessToken, err := c.GetAccessToken() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if len(accessToken) == 0 { | ||||||
|  | 		return nil, errors.New("empty access token") | ||||||
|  | 	} | ||||||
|  | 	body := &bytes.Buffer{} | ||||||
|  | 	writer := multipart.NewWriter(body) | ||||||
|  | 	part, err := writer.CreateFormFile("media", filename) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	_, err = part.Write(content) | ||||||
|  | 	writer.WriteField("type", mediaType) | ||||||
|  | 	err = writer.Close() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Create a new HTTP request to upload the media file | ||||||
|  | 	url := fmt.Sprintf("https://oapi.dingtalk.com/media/upload?access_token=%s", url2.QueryEscape(accessToken)) | ||||||
|  | 	req, err := http.NewRequest("POST", url, body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	req.Header.Set("Content-Type", writer.FormDataContentType()) | ||||||
|  |  | ||||||
|  | 	// Send the HTTP request and parse the response | ||||||
|  | 	client := &http.Client{ | ||||||
|  | 		Timeout: time.Second * 60, | ||||||
|  | 	} | ||||||
|  | 	res, err := client.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	defer res.Body.Close() | ||||||
|  |  | ||||||
|  | 	// Parse the response body as JSON and extract the media ID | ||||||
|  | 	media := &MediaUploadResult{} | ||||||
|  | 	bodyBytes, err := io.ReadAll(res.Body) | ||||||
|  | 	json.Unmarshal(bodyBytes, media) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if media.ErrorCode != 0 { | ||||||
|  | 		return nil, errors.New(media.ErrorMessage) | ||||||
|  | 	} | ||||||
|  | 	return media, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *DingTalkClient) getAccessTokenFromDingTalk() (*OAuthTokenResult, error) { | ||||||
|  | 	// OpenAPI doc: https://open.dingtalk.com/document/orgapp/obtain-orgapp-token | ||||||
|  | 	apiUrl := "https://oapi.dingtalk.com/gettoken" | ||||||
|  | 	queryParams := url2.Values{} | ||||||
|  | 	queryParams.Add("appkey", c.Credential.ClientID) | ||||||
|  | 	queryParams.Add("appsecret", c.Credential.ClientSecret) | ||||||
|  |  | ||||||
|  | 	// Create a new HTTP request to get the AccessToken | ||||||
|  | 	req, err := http.NewRequest("GET", apiUrl+"?"+queryParams.Encode(), nil) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Send the HTTP request and parse the response body as JSON | ||||||
|  | 	client := http.Client{ | ||||||
|  | 		Timeout: time.Second * 60, | ||||||
|  | 	} | ||||||
|  | 	res, err := client.Do(req) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	defer res.Body.Close() | ||||||
|  | 	body, err := io.ReadAll(res.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	tokenResult := &OAuthTokenResult{} | ||||||
|  | 	err = json.Unmarshal(body, tokenResult) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if tokenResult.ErrorCode != 0 { | ||||||
|  | 		return nil, errors.New(tokenResult.ErrorMessage) | ||||||
|  | 	} | ||||||
|  | 	return tokenResult, nil | ||||||
|  | } | ||||||
							
								
								
									
										53
									
								
								pkg/dingbot/client_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								pkg/dingbot/client_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | |||||||
|  | package dingbot | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"github.com/eryajf/chatgpt-dingtalk/config" | ||||||
|  | 	"image" | ||||||
|  | 	"image/color" | ||||||
|  | 	"image/png" | ||||||
|  | 	"os" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestUploadMedia_Pass_WithValidConfig(t *testing.T) { | ||||||
|  | 	// 设置了钉钉 ClientID 和 ClientSecret 的环境变量才执行以下测试,用于快速验证钉钉图片上传能力 | ||||||
|  | 	clientId, clientSecret := os.Getenv("DINGTALK_CLIENT_ID_FOR_TEST"), os.Getenv("DINGTALK_CLIENT_SECRET_FOR_TEST") | ||||||
|  | 	if len(clientId) <= 0 || len(clientSecret) <= 0 { | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	credentials := []config.Credential{ | ||||||
|  | 		config.Credential{ | ||||||
|  | 			ClientID:     clientId, | ||||||
|  | 			ClientSecret: clientSecret, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	client := NewDingTalkClientManager(&config.Configuration{Credentials: credentials}).GetClientByOAuthClientID(clientId) | ||||||
|  | 	var imageContent []byte | ||||||
|  | 	{ | ||||||
|  | 		// 生成一张用于测试的图片 | ||||||
|  | 		img := image.NewRGBA(image.Rect(0, 0, 200, 100)) | ||||||
|  | 		blue := color.RGBA{0, 0, 255, 255} | ||||||
|  | 		for x := 0; x < img.Bounds().Dx(); x++ { | ||||||
|  | 			for y := 0; y < img.Bounds().Dy(); y++ { | ||||||
|  | 				img.Set(x, y, blue) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		buf := new(bytes.Buffer) | ||||||
|  | 		err := png.Encode(buf, img) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		// get the byte array from the buffer | ||||||
|  | 		imageContent = buf.Bytes() | ||||||
|  | 	} | ||||||
|  | 	result, err := client.UploadMedia(imageContent, "filename.png", "image", "image/png") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("upload media failed, err=%s", err.Error()) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if result.MediaID == "" { | ||||||
|  | 		t.Errorf("upload media failed, empty media id") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -1,6 +1,7 @@ | |||||||
| package process | package process | ||||||
|  |  | ||||||
| import ( | import ( | ||||||
|  | 	"context" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"github.com/eryajf/chatgpt-dingtalk/public" | 	"github.com/eryajf/chatgpt-dingtalk/public" | ||||||
| 	"strings" | 	"strings" | ||||||
| @@ -12,7 +13,7 @@ import ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| // ImageGenerate openai生成图片 | // ImageGenerate openai生成图片 | ||||||
| func ImageGenerate(rmsg *dingbot.ReceiveMsg) error { | func ImageGenerate(ctx context.Context, rmsg *dingbot.ReceiveMsg) error { | ||||||
| 	if public.Config.AzureOn { | 	if public.Config.AzureOn { | ||||||
| 		_, err := rmsg.ReplyToDingtalk(string(dingbot. | 		_, err := rmsg.ReplyToDingtalk(string(dingbot. | ||||||
| 			MARKDOWN), "azure 模式下暂不支持图片创作功能") | 			MARKDOWN), "azure 模式下暂不支持图片创作功能") | ||||||
| @@ -32,7 +33,7 @@ func ImageGenerate(rmsg *dingbot.ReceiveMsg) error { | |||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Error("往MySQL新增数据失败,错误信息:", err) | 		logger.Error("往MySQL新增数据失败,错误信息:", err) | ||||||
| 	} | 	} | ||||||
| 	reply, err := chatgpt.ImageQa(rmsg.Text.Content, rmsg.GetSenderIdentifier()) | 	reply, err := chatgpt.ImageQa(ctx, rmsg.Text.Content, rmsg.GetSenderIdentifier()) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		logger.Info(fmt.Errorf("gpt request error: %v", err)) | 		logger.Info(fmt.Errorf("gpt request error: %v", err)) | ||||||
| 		_, err = rmsg.ReplyToDingtalk(string(dingbot.TEXT), fmt.Sprintf("请求openai失败了,错误信息:%v", err)) | 		_, err = rmsg.ReplyToDingtalk(string(dingbot.TEXT), fmt.Sprintf("请求openai失败了,错误信息:%v", err)) | ||||||
|   | |||||||
| @@ -4,12 +4,16 @@ import ( | |||||||
| 	"github.com/eryajf/chatgpt-dingtalk/config" | 	"github.com/eryajf/chatgpt-dingtalk/config" | ||||||
| 	"github.com/eryajf/chatgpt-dingtalk/pkg/cache" | 	"github.com/eryajf/chatgpt-dingtalk/pkg/cache" | ||||||
| 	"github.com/eryajf/chatgpt-dingtalk/pkg/db" | 	"github.com/eryajf/chatgpt-dingtalk/pkg/db" | ||||||
|  | 	"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot" | ||||||
| 	"github.com/sashabaranov/go-openai" | 	"github.com/sashabaranov/go-openai" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| var UserService cache.UserServiceInterface | var UserService cache.UserServiceInterface | ||||||
| var Config *config.Configuration | var Config *config.Configuration | ||||||
| var Prompt *[]config.Prompt | var Prompt *[]config.Prompt | ||||||
|  | var DingTalkClientManager dingbot.DingTalkClientManagerInterface | ||||||
|  |  | ||||||
|  | const DingTalkClientIdKeyName = "DingTalkClientId" | ||||||
|  |  | ||||||
| func InitSvc() { | func InitSvc() { | ||||||
| 	// 加载配置 | 	// 加载配置 | ||||||
| @@ -18,6 +22,8 @@ func InitSvc() { | |||||||
| 	Prompt = config.LoadPrompt() | 	Prompt = config.LoadPrompt() | ||||||
| 	// 初始化缓存 | 	// 初始化缓存 | ||||||
| 	UserService = cache.NewUserService() | 	UserService = cache.NewUserService() | ||||||
|  | 	// 初始化钉钉开放平台的客户端,用于访问上传图片等能力 | ||||||
|  | 	DingTalkClientManager = dingbot.NewDingTalkClientManager(Config) | ||||||
| 	// 初始化数据库 | 	// 初始化数据库 | ||||||
| 	db.InitDB() | 	db.InitDB() | ||||||
| 	// 暂时不在初始化时获取余额 | 	// 暂时不在初始化时获取余额 | ||||||
|   | |||||||
| @@ -124,6 +124,23 @@ func GetReadTime(t time.Time) string { | |||||||
| 	return t.Format("2006-01-02 15:04:05") | 	return t.Format("2006-01-02 15:04:05") | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func CheckRequestWithCredentials(ts, sg string) (clientId string, pass bool) { | ||||||
|  | 	clientId, pass = "", false | ||||||
|  | 	credentials := Config.Credentials | ||||||
|  | 	if credentials == nil || len(credentials) == 0 { | ||||||
|  | 		return "", true | ||||||
|  | 	} | ||||||
|  | 	for _, credential := range Config.Credentials { | ||||||
|  | 		stringToSign := fmt.Sprintf("%s\n%s", ts, credential.ClientSecret) | ||||||
|  | 		mac := hmac.New(sha256.New, []byte(credential.ClientSecret)) | ||||||
|  | 		_, _ = mac.Write([]byte(stringToSign)) | ||||||
|  | 		if base64.StdEncoding.EncodeToString(mac.Sum(nil)) == sg { | ||||||
|  | 			return credential.ClientID, true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
| func CheckRequest(ts, sg string) bool { | func CheckRequest(ts, sg string) bool { | ||||||
| 	appSecrets := Config.AppSecrets | 	appSecrets := Config.AppSecrets | ||||||
| 	// 如果没有指定或者outgoing类型机器人下使用,则默认不做校验 | 	// 如果没有指定或者outgoing类型机器人下使用,则默认不做校验 | ||||||
|   | |||||||
							
								
								
									
										76
									
								
								public/tools_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								public/tools_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,76 @@ | |||||||
|  | package public | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"github.com/eryajf/chatgpt-dingtalk/config" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestCheckRequestWithCredentials_Pass_WithNilConfig(t *testing.T) { | ||||||
|  | 	Config = &config.Configuration{ | ||||||
|  | 		Credentials: nil, | ||||||
|  | 	} | ||||||
|  | 	clientId, pass := CheckRequestWithCredentials("ts", "sg") | ||||||
|  | 	if !pass { | ||||||
|  | 		t.Errorf("pass should be true, but false") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if len(clientId) > 0 { | ||||||
|  | 		t.Errorf("client id should be empty") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestCheckRequestWithCredentials_Pass_WithEmptyConfig(t *testing.T) { | ||||||
|  | 	Config = &config.Configuration{ | ||||||
|  | 		Credentials: []config.Credential{}, | ||||||
|  | 	} | ||||||
|  | 	clientId, pass := CheckRequestWithCredentials("ts", "sg") | ||||||
|  | 	if !pass { | ||||||
|  | 		t.Errorf("pass should be true, but false") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if len(clientId) > 0 { | ||||||
|  | 		t.Errorf("client id should be empty") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestCheckRequestWithCredentials_Pass_WithValidConfig(t *testing.T) { | ||||||
|  | 	Config = &config.Configuration{ | ||||||
|  | 		Credentials: []config.Credential{ | ||||||
|  | 			config.Credential{ | ||||||
|  | 				ClientID:     "client-id-for-test", | ||||||
|  | 				ClientSecret: "client-secret-for-test", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	clientId, pass := CheckRequestWithCredentials("1684493546276", "nwBJQmaBLv9+5/sSS/66jcFc1/kGY5wo38L88LOGfRU=") | ||||||
|  | 	if !pass { | ||||||
|  | 		t.Errorf("pass should be true, but false") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if clientId != "client-id-for-test" { | ||||||
|  | 		t.Errorf("client id should be \"%s\", but \"%s\"", "client-id-for-test", clientId) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestCheckRequestWithCredentials_Failed_WithInvalidConfig(t *testing.T) { | ||||||
|  | 	Config = &config.Configuration{ | ||||||
|  | 		Credentials: []config.Credential{ | ||||||
|  | 			config.Credential{ | ||||||
|  | 				ClientID:     "client-id-for-test", | ||||||
|  | 				ClientSecret: "invalid-client-secret-for-test", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 	clientId, pass := CheckRequestWithCredentials("1684493546276", "nwBJQmaBLv9+5/sSS/66jcFc1/kGY5wo38L88LOGfRU=") | ||||||
|  | 	if pass { | ||||||
|  | 		t.Errorf("pass should be false, but true") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | 	if clientId != "" { | ||||||
|  | 		t.Errorf("client id should be empty") | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user
	 金喜@DingTalk
					金喜@DingTalk