mirror of
				https://github.com/eryajf/chatgpt-dingtalk.git
				synced 2025-10-31 11:36:17 +08:00 
			
		
		
		
	feat: 支持上传图片到钉钉平台,在图片生成流程中使用钉钉的图片 CDN 能力 (#225)
This commit is contained in:
		
							
								
								
									
										10
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								README.md
									
									
									
									
									
								
							| @@ -225,6 +225,7 @@ $ docker run -itd --name chatgpt -p 8090:8090 \ | ||||
|   -e SENSITIVE_WORDS="aa,bb" \ | ||||
|   -e AZURE_ON="false" -e AZURE_API_VERSION="" -e AZURE_RESOURCE_NAME="" \ | ||||
|   -e AZURE_DEPLOYMENT_NAME="" -e AZURE_OPENAI_TOKEN="" \ | ||||
|   -e DINGTALK_CREDENTIALS="your_client_id1:secret1,your_client_id2:secret2" \ | ||||
|   -e HELP="欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/) | ||||
|   ,觉得不错你可以来波素质三连."  \ | ||||
|   --restart=always  dockerproxy.com/eryajf/chatgpt-dingtalk:latest | ||||
| @@ -541,6 +542,15 @@ azure_resource_name: "xxxx" | ||||
| azure_deployment_name: "xxxx" | ||||
| azure_openai_token: "xxxx" | ||||
|  | ||||
| # 钉钉应用鉴权凭据信息,支持多个应用。通过请求时候鉴权来识别是来自哪个机器人应用的消息 | ||||
| # 设置credentials 之后,即具备了访问钉钉平台绝大部分 OpenAPI 的能力;例如上传图片到钉钉平台,提升图片体验,结合 Stream 模式简化服务部署 | ||||
| # client_id 对应钉钉平台 AppKey/SuiteKey;client_secret 对应 AppSecret/SuiteSecret | ||||
| # 建议采用 credentials 代替 app_secrets 配置项,以获得钉钉 OpenAPI 访问能力 | ||||
| credentials: | ||||
|   - | ||||
|     client_id: "put-your-client-id-here" | ||||
|     client_secret: "put-your-client-secret-here" | ||||
|  | ||||
| ``` | ||||
|  | ||||
| ## 常见问题 | ||||
|   | ||||
| @@ -59,3 +59,10 @@ azure_resource_name: "xxxx" | ||||
| azure_deployment_name: "xxxx" | ||||
| azure_openai_token: "xxxx" | ||||
|  | ||||
| # 钉钉应用鉴权凭据信息,支持多个应用。通过请求时候鉴权来识别是来自哪个机器人应用的消息 | ||||
| # 设置credentials 之后,即具备了访问钉钉平台绝大部分 OpenAPI 的能力;例如上传图片到钉钉平台,提升图片体验,结合 Stream 模式简化服务部署 | ||||
| # client_id 对应钉钉平台 AppKey/SuiteKey;client_secret 对应 AppSecret/SuiteSecret | ||||
| #credentials: | ||||
| #  - | ||||
| #    client_id: "put-your-client-id-here" | ||||
| #    client_secret: "put-your-client-secret-here" | ||||
|   | ||||
| @@ -14,6 +14,11 @@ import ( | ||||
| 	"gopkg.in/yaml.v2" | ||||
| ) | ||||
|  | ||||
| type Credential struct { | ||||
| 	ClientID     string `yaml:"client_id"` | ||||
| 	ClientSecret string `yaml:"client_secret"` | ||||
| } | ||||
|  | ||||
| // Configuration 项目配置 | ||||
| type Configuration struct { | ||||
| 	// 日志级别,info或者debug | ||||
| @@ -62,6 +67,8 @@ type Configuration struct { | ||||
| 	AzureResourceName   string `yaml:"azure_resource_name"` | ||||
| 	AzureDeploymentName string `yaml:"azure_deployment_name"` | ||||
| 	AzureOpenAIToken    string `yaml:"azure_openai_token"` | ||||
| 	// 钉钉应用鉴权凭据 | ||||
| 	Credentials []Credential `yaml:"credentials"` | ||||
| } | ||||
|  | ||||
| var config *Configuration | ||||
| @@ -190,6 +197,18 @@ func LoadConfig() *Configuration { | ||||
| 		if azureOpenaiToken != "" { | ||||
| 			config.AzureOpenAIToken = azureOpenaiToken | ||||
| 		} | ||||
| 		credentials := os.Getenv("DINGTALK_CREDENTIALS") | ||||
| 		if credentials != "" { | ||||
| 			if config.Credentials == nil { | ||||
| 				config.Credentials = []Credential{} | ||||
| 			} | ||||
| 			for _, idSecret := range strings.Split(credentials, ",") { | ||||
| 				items := strings.SplitN(idSecret, ":", 2) | ||||
| 				if len(items) == 2 { | ||||
| 					config.Credentials = append(config.Credentials, Credential{ClientID: items[0], ClientSecret: items[1]}) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 	}) | ||||
|  | ||||
|   | ||||
| @@ -37,6 +37,7 @@ services: | ||||
|       AZURE_RESOURCE_NAME: "" # Azure OpenAi API 资源名称,比如 "openai" | ||||
|       AZURE_DEPLOYMENT_NAME: "" # Azure OpenAi API 部署名称,比如 "openai" | ||||
|       AZURE_OPENAI_TOKEN: "" # Azure token | ||||
|       DINGTALK_CREDENTIALS: "" # 钉钉应用访问凭证,比如 "client_id1:secret1,client_id2:secret2" | ||||
|       HELP: "欢迎使用本工具\n\n你可以查看:[用户指南](https://github.com/eryajf/chatgpt-dingtalk/blob/main/docs/userGuide.md)\n\n这是一个[开源项目](https://github.com/eryajf/chatgpt-dingtalk/),觉得不错你可以来波素质三连." # 帮助信息,放在配置文件,可供自定义 | ||||
|     volumes: | ||||
|       - ./data:/app/data | ||||
|   | ||||
							
								
								
									
										10
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								main.go
									
									
									
									
									
								
							| @@ -33,6 +33,14 @@ func Start() { | ||||
| 			return | ||||
| 		} | ||||
| 		// 先校验回调是否合法 | ||||
| 		clientId, checkOk := public.CheckRequestWithCredentials(c.GetHeader("timestamp"), c.GetHeader("sign")) | ||||
| 		if !checkOk { | ||||
| 			logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!") | ||||
| 			return | ||||
| 		} | ||||
| 		// 通过 context 传递 OAuth ClientID,用于后续流程中调用钉钉OpenAPI | ||||
| 		c.Set(public.DingTalkClientIdKeyName, clientId) | ||||
| 		// 为了兼容存量老用户,暂时保留 public.CheckRequest 方法,将来升级到 Stream 模式后,建议去除该方法,采用上面的 CheckRequestWithCredentials | ||||
| 		if !public.CheckRequest(c.GetHeader("timestamp"), c.GetHeader("sign")) && msgObj.SenderStaffId != "" { | ||||
| 			logger.Warning("该请求不合法,可能是其他企业或者未经允许的应用调用所致,请知悉!") | ||||
| 			return | ||||
| @@ -114,7 +122,7 @@ func Start() { | ||||
| 			// 除去帮助之外的逻辑分流在这里处理 | ||||
| 			switch { | ||||
| 			case strings.HasPrefix(msgObj.Text.Content, "#图片"): | ||||
| 				err := process.ImageGenerate(&msgObj) | ||||
| 				err := process.ImageGenerate(c, &msgObj) | ||||
| 				if err != nil { | ||||
| 					logger.Warning(fmt.Errorf("process request: %v", err)) | ||||
| 					return | ||||
|   | ||||
| @@ -2,8 +2,12 @@ package chatgpt | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"context" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/gob" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot" | ||||
| 	"github.com/pandodao/tokenizer-go" | ||||
| 	"image/png" | ||||
| 	"os" | ||||
| @@ -218,7 +222,7 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) { | ||||
| 		return resp.Choices[0].Text, nil | ||||
| 	} | ||||
| } | ||||
| func (c *ChatGPT) GenreateImage(prompt string) (string, error) { | ||||
| func (c *ChatGPT) GenreateImage(ctx context.Context, prompt string) (string, error) { | ||||
| 	model := public.Config.Model | ||||
| 	if model == openai.GPT3Dot5Turbo0301 || | ||||
| 		model == openai.GPT3Dot5Turbo || | ||||
| @@ -247,6 +251,13 @@ func (c *ChatGPT) GenreateImage(prompt string) (string, error) { | ||||
| 		} | ||||
|  | ||||
| 		imageName := time.Now().Format("20060102-150405") + ".png" | ||||
| 		clientId, _ := ctx.Value(public.DingTalkClientIdKeyName).(string) | ||||
| 		client := public.DingTalkClientManager.GetClientByOAuthClientID(clientId) | ||||
| 		mediaResult, uploadErr := &dingbot.MediaUploadResult{}, errors.New(fmt.Sprintf("unknown clientId: %s", clientId)) | ||||
| 		if client != nil { | ||||
| 			mediaResult, uploadErr = client.UploadMedia(imgBytes, imageName, dingbot.MediaTypeImage, dingbot.MimeTypeImagePng) | ||||
| 		} | ||||
|  | ||||
| 		err = os.MkdirAll("data/images", 0755) | ||||
| 		if err != nil { | ||||
| 			return "", err | ||||
| @@ -260,9 +271,12 @@ func (c *ChatGPT) GenreateImage(prompt string) (string, error) { | ||||
| 		if err := png.Encode(file, imgData); err != nil { | ||||
| 			return "", err | ||||
| 		} | ||||
|  | ||||
| 		if uploadErr == nil { | ||||
| 			return mediaResult.MediaID, nil | ||||
| 		} else { | ||||
| 			return public.Config.ServiceURL + "/images/" + imageName, nil | ||||
| 		} | ||||
| 	} | ||||
| 	return "", nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| package chatgpt | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/avast/retry-go" | ||||
| @@ -58,7 +59,7 @@ func ContextQa(question, userId string) (chat *ChatGPT, answer string, err error | ||||
| } | ||||
|  | ||||
| // ImageQa 生成图片 | ||||
| func ImageQa(question, userId string) (answer string, err error) { | ||||
| func ImageQa(ctx context.Context, question, userId string) (answer string, err error) { | ||||
| 	chat := New(userId) | ||||
| 	defer chat.Close() | ||||
| 	// 定义一个重试策略 | ||||
| @@ -70,7 +71,7 @@ func ImageQa(question, userId string) (answer string, err error) { | ||||
| 	// 使用重试策略进行重试 | ||||
| 	err = retry.Do( | ||||
| 		func() error { | ||||
| 			answer, err = chat.GenreateImage(question) | ||||
| 			answer, err = chat.GenreateImage(ctx, question) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
|   | ||||
							
								
								
									
										213
									
								
								pkg/dingbot/client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										213
									
								
								pkg/dingbot/client.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,213 @@ | ||||
| package dingbot | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/eryajf/chatgpt-dingtalk/config" | ||||
| 	"io" | ||||
| 	"mime/multipart" | ||||
| 	"net/http" | ||||
| 	url2 "net/url" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // OpenAPI doc: https://open.dingtalk.com/document/isvapp/upload-media-files | ||||
| const ( | ||||
| 	MediaTypeImage string = "image" | ||||
| 	MediaTypeVoice string = "voice" | ||||
| 	MediaTypeVideo string = "video" | ||||
| 	MediaTypeFile  string = "file" | ||||
| ) | ||||
| const ( | ||||
| 	MimeTypeImagePng string = "image/png" | ||||
| ) | ||||
|  | ||||
| type MediaUploadResult struct { | ||||
| 	ErrorCode    int64  `json:"errcode"` | ||||
| 	ErrorMessage string `json:"errmsg"` | ||||
| 	MediaID      string `json:"media_id"` | ||||
| 	CreatedAt    int64  `json:"created_at"` | ||||
| 	Type         string `json:"type"` | ||||
| } | ||||
|  | ||||
| type OAuthTokenResult struct { | ||||
| 	ErrorCode    int    `json:"errcode"` | ||||
| 	ErrorMessage string `json:"errmsg"` | ||||
| 	AccessToken  string `json:"access_token"` | ||||
| 	ExpiresIn    int    `json:"expires_in"` | ||||
| } | ||||
|  | ||||
| type DingTalkClientInterface interface { | ||||
| 	GetAccessToken() (string, error) | ||||
| 	UploadMedia(content []byte, filename, mediaType, mimeType string) (*MediaUploadResult, error) | ||||
| } | ||||
|  | ||||
| type DingTalkClientManagerInterface interface { | ||||
| 	GetClientByOAuthClientID(clientId string) DingTalkClientInterface | ||||
| } | ||||
|  | ||||
| type DingTalkClient struct { | ||||
| 	Credential  config.Credential | ||||
| 	AccessToken string | ||||
| 	expireAt    int64 | ||||
| 	mutex       sync.Mutex | ||||
| } | ||||
|  | ||||
| type DingTalkClientManager struct { | ||||
| 	Credentials []config.Credential | ||||
| 	Clients     map[string]*DingTalkClient | ||||
| 	mutex       sync.Mutex | ||||
| } | ||||
|  | ||||
| func NewDingTalkClient(credential config.Credential) *DingTalkClient { | ||||
| 	return &DingTalkClient{ | ||||
| 		Credential: credential, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func NewDingTalkClientManager(conf *config.Configuration) *DingTalkClientManager { | ||||
| 	clients := make(map[string]*DingTalkClient) | ||||
|  | ||||
| 	if conf != nil && conf.Credentials != nil { | ||||
| 		for _, credential := range conf.Credentials { | ||||
| 			clients[credential.ClientID] = NewDingTalkClient(credential) | ||||
| 		} | ||||
| 	} | ||||
| 	return &DingTalkClientManager{ | ||||
| 		Credentials: conf.Credentials, | ||||
| 		Clients:     clients, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (m *DingTalkClientManager) GetClientByOAuthClientID(clientId string) DingTalkClientInterface { | ||||
| 	m.mutex.Lock() | ||||
| 	defer m.mutex.Unlock() | ||||
| 	if client, ok := m.Clients[clientId]; ok { | ||||
| 		return client | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (c *DingTalkClient) GetAccessToken() (string, error) { | ||||
| 	accessToken := "" | ||||
| 	{ | ||||
| 		// 先查询缓存 | ||||
| 		c.mutex.Lock() | ||||
| 		now := time.Now().Unix() | ||||
| 		if c.expireAt > 0 && c.AccessToken != "" && (now+60) < c.expireAt { | ||||
| 			// 预留一分钟有效期避免在Token过期的临界点调用接口出现401错误 | ||||
| 			accessToken = c.AccessToken | ||||
| 		} | ||||
| 		c.mutex.Unlock() | ||||
| 	} | ||||
| 	if accessToken != "" { | ||||
| 		return accessToken, nil | ||||
| 	} | ||||
|  | ||||
| 	tokenResult, err := c.getAccessTokenFromDingTalk() | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	{ | ||||
| 		// 更新缓存 | ||||
| 		c.mutex.Lock() | ||||
| 		c.AccessToken = tokenResult.AccessToken | ||||
| 		c.expireAt = time.Now().Unix() + int64(tokenResult.ExpiresIn) | ||||
| 		c.mutex.Unlock() | ||||
| 	} | ||||
| 	return tokenResult.AccessToken, nil | ||||
| } | ||||
|  | ||||
| func (c *DingTalkClient) UploadMedia(content []byte, filename, mediaType, mimeType string) (*MediaUploadResult, error) { | ||||
| 	// OpenAPI doc: https://open.dingtalk.com/document/isvapp/upload-media-files | ||||
| 	accessToken, err := c.GetAccessToken() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if len(accessToken) == 0 { | ||||
| 		return nil, errors.New("empty access token") | ||||
| 	} | ||||
| 	body := &bytes.Buffer{} | ||||
| 	writer := multipart.NewWriter(body) | ||||
| 	part, err := writer.CreateFormFile("media", filename) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	_, err = part.Write(content) | ||||
| 	writer.WriteField("type", mediaType) | ||||
| 	err = writer.Close() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// Create a new HTTP request to upload the media file | ||||
| 	url := fmt.Sprintf("https://oapi.dingtalk.com/media/upload?access_token=%s", url2.QueryEscape(accessToken)) | ||||
| 	req, err := http.NewRequest("POST", url, body) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	req.Header.Set("Content-Type", writer.FormDataContentType()) | ||||
|  | ||||
| 	// Send the HTTP request and parse the response | ||||
| 	client := &http.Client{ | ||||
| 		Timeout: time.Second * 60, | ||||
| 	} | ||||
| 	res, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
|  | ||||
| 	// Parse the response body as JSON and extract the media ID | ||||
| 	media := &MediaUploadResult{} | ||||
| 	bodyBytes, err := io.ReadAll(res.Body) | ||||
| 	json.Unmarshal(bodyBytes, media) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if media.ErrorCode != 0 { | ||||
| 		return nil, errors.New(media.ErrorMessage) | ||||
| 	} | ||||
| 	return media, nil | ||||
| } | ||||
|  | ||||
| func (c *DingTalkClient) getAccessTokenFromDingTalk() (*OAuthTokenResult, error) { | ||||
| 	// OpenAPI doc: https://open.dingtalk.com/document/orgapp/obtain-orgapp-token | ||||
| 	apiUrl := "https://oapi.dingtalk.com/gettoken" | ||||
| 	queryParams := url2.Values{} | ||||
| 	queryParams.Add("appkey", c.Credential.ClientID) | ||||
| 	queryParams.Add("appsecret", c.Credential.ClientSecret) | ||||
|  | ||||
| 	// Create a new HTTP request to get the AccessToken | ||||
| 	req, err := http.NewRequest("GET", apiUrl+"?"+queryParams.Encode(), nil) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// Send the HTTP request and parse the response body as JSON | ||||
| 	client := http.Client{ | ||||
| 		Timeout: time.Second * 60, | ||||
| 	} | ||||
| 	res, err := client.Do(req) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
| 	body, err := io.ReadAll(res.Body) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	tokenResult := &OAuthTokenResult{} | ||||
| 	err = json.Unmarshal(body, tokenResult) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if tokenResult.ErrorCode != 0 { | ||||
| 		return nil, errors.New(tokenResult.ErrorMessage) | ||||
| 	} | ||||
| 	return tokenResult, nil | ||||
| } | ||||
							
								
								
									
										53
									
								
								pkg/dingbot/client_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								pkg/dingbot/client_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | ||||
| package dingbot | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"github.com/eryajf/chatgpt-dingtalk/config" | ||||
| 	"image" | ||||
| 	"image/color" | ||||
| 	"image/png" | ||||
| 	"os" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestUploadMedia_Pass_WithValidConfig(t *testing.T) { | ||||
| 	// 设置了钉钉 ClientID 和 ClientSecret 的环境变量才执行以下测试,用于快速验证钉钉图片上传能力 | ||||
| 	clientId, clientSecret := os.Getenv("DINGTALK_CLIENT_ID_FOR_TEST"), os.Getenv("DINGTALK_CLIENT_SECRET_FOR_TEST") | ||||
| 	if len(clientId) <= 0 || len(clientSecret) <= 0 { | ||||
| 		return | ||||
| 	} | ||||
| 	credentials := []config.Credential{ | ||||
| 		config.Credential{ | ||||
| 			ClientID:     clientId, | ||||
| 			ClientSecret: clientSecret, | ||||
| 		}, | ||||
| 	} | ||||
| 	client := NewDingTalkClientManager(&config.Configuration{Credentials: credentials}).GetClientByOAuthClientID(clientId) | ||||
| 	var imageContent []byte | ||||
| 	{ | ||||
| 		// 生成一张用于测试的图片 | ||||
| 		img := image.NewRGBA(image.Rect(0, 0, 200, 100)) | ||||
| 		blue := color.RGBA{0, 0, 255, 255} | ||||
| 		for x := 0; x < img.Bounds().Dx(); x++ { | ||||
| 			for y := 0; y < img.Bounds().Dy(); y++ { | ||||
| 				img.Set(x, y, blue) | ||||
| 			} | ||||
| 		} | ||||
| 		buf := new(bytes.Buffer) | ||||
| 		err := png.Encode(buf, img) | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 		} | ||||
| 		// get the byte array from the buffer | ||||
| 		imageContent = buf.Bytes() | ||||
| 	} | ||||
| 	result, err := client.UploadMedia(imageContent, "filename.png", "image", "image/png") | ||||
| 	if err != nil { | ||||
| 		t.Errorf("upload media failed, err=%s", err.Error()) | ||||
| 		return | ||||
| 	} | ||||
| 	if result.MediaID == "" { | ||||
| 		t.Errorf("upload media failed, empty media id") | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
| @@ -1,6 +1,7 @@ | ||||
| package process | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"github.com/eryajf/chatgpt-dingtalk/public" | ||||
| 	"strings" | ||||
| @@ -12,7 +13,7 @@ import ( | ||||
| ) | ||||
|  | ||||
| // ImageGenerate openai生成图片 | ||||
| func ImageGenerate(rmsg *dingbot.ReceiveMsg) error { | ||||
| func ImageGenerate(ctx context.Context, rmsg *dingbot.ReceiveMsg) error { | ||||
| 	if public.Config.AzureOn { | ||||
| 		_, err := rmsg.ReplyToDingtalk(string(dingbot. | ||||
| 			MARKDOWN), "azure 模式下暂不支持图片创作功能") | ||||
| @@ -32,7 +33,7 @@ func ImageGenerate(rmsg *dingbot.ReceiveMsg) error { | ||||
| 	if err != nil { | ||||
| 		logger.Error("往MySQL新增数据失败,错误信息:", err) | ||||
| 	} | ||||
| 	reply, err := chatgpt.ImageQa(rmsg.Text.Content, rmsg.GetSenderIdentifier()) | ||||
| 	reply, err := chatgpt.ImageQa(ctx, rmsg.Text.Content, rmsg.GetSenderIdentifier()) | ||||
| 	if err != nil { | ||||
| 		logger.Info(fmt.Errorf("gpt request error: %v", err)) | ||||
| 		_, err = rmsg.ReplyToDingtalk(string(dingbot.TEXT), fmt.Sprintf("请求openai失败了,错误信息:%v", err)) | ||||
|   | ||||
| @@ -4,12 +4,16 @@ import ( | ||||
| 	"github.com/eryajf/chatgpt-dingtalk/config" | ||||
| 	"github.com/eryajf/chatgpt-dingtalk/pkg/cache" | ||||
| 	"github.com/eryajf/chatgpt-dingtalk/pkg/db" | ||||
| 	"github.com/eryajf/chatgpt-dingtalk/pkg/dingbot" | ||||
| 	"github.com/sashabaranov/go-openai" | ||||
| ) | ||||
|  | ||||
| var UserService cache.UserServiceInterface | ||||
| var Config *config.Configuration | ||||
| var Prompt *[]config.Prompt | ||||
| var DingTalkClientManager dingbot.DingTalkClientManagerInterface | ||||
|  | ||||
| const DingTalkClientIdKeyName = "DingTalkClientId" | ||||
|  | ||||
| func InitSvc() { | ||||
| 	// 加载配置 | ||||
| @@ -18,6 +22,8 @@ func InitSvc() { | ||||
| 	Prompt = config.LoadPrompt() | ||||
| 	// 初始化缓存 | ||||
| 	UserService = cache.NewUserService() | ||||
| 	// 初始化钉钉开放平台的客户端,用于访问上传图片等能力 | ||||
| 	DingTalkClientManager = dingbot.NewDingTalkClientManager(Config) | ||||
| 	// 初始化数据库 | ||||
| 	db.InitDB() | ||||
| 	// 暂时不在初始化时获取余额 | ||||
|   | ||||
| @@ -124,6 +124,23 @@ func GetReadTime(t time.Time) string { | ||||
| 	return t.Format("2006-01-02 15:04:05") | ||||
| } | ||||
|  | ||||
| func CheckRequestWithCredentials(ts, sg string) (clientId string, pass bool) { | ||||
| 	clientId, pass = "", false | ||||
| 	credentials := Config.Credentials | ||||
| 	if credentials == nil || len(credentials) == 0 { | ||||
| 		return "", true | ||||
| 	} | ||||
| 	for _, credential := range Config.Credentials { | ||||
| 		stringToSign := fmt.Sprintf("%s\n%s", ts, credential.ClientSecret) | ||||
| 		mac := hmac.New(sha256.New, []byte(credential.ClientSecret)) | ||||
| 		_, _ = mac.Write([]byte(stringToSign)) | ||||
| 		if base64.StdEncoding.EncodeToString(mac.Sum(nil)) == sg { | ||||
| 			return credential.ClientID, true | ||||
| 		} | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func CheckRequest(ts, sg string) bool { | ||||
| 	appSecrets := Config.AppSecrets | ||||
| 	// 如果没有指定或者outgoing类型机器人下使用,则默认不做校验 | ||||
|   | ||||
							
								
								
									
										76
									
								
								public/tools_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								public/tools_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,76 @@ | ||||
| package public | ||||
|  | ||||
| import ( | ||||
| 	"github.com/eryajf/chatgpt-dingtalk/config" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestCheckRequestWithCredentials_Pass_WithNilConfig(t *testing.T) { | ||||
| 	Config = &config.Configuration{ | ||||
| 		Credentials: nil, | ||||
| 	} | ||||
| 	clientId, pass := CheckRequestWithCredentials("ts", "sg") | ||||
| 	if !pass { | ||||
| 		t.Errorf("pass should be true, but false") | ||||
| 		return | ||||
| 	} | ||||
| 	if len(clientId) > 0 { | ||||
| 		t.Errorf("client id should be empty") | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestCheckRequestWithCredentials_Pass_WithEmptyConfig(t *testing.T) { | ||||
| 	Config = &config.Configuration{ | ||||
| 		Credentials: []config.Credential{}, | ||||
| 	} | ||||
| 	clientId, pass := CheckRequestWithCredentials("ts", "sg") | ||||
| 	if !pass { | ||||
| 		t.Errorf("pass should be true, but false") | ||||
| 		return | ||||
| 	} | ||||
| 	if len(clientId) > 0 { | ||||
| 		t.Errorf("client id should be empty") | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestCheckRequestWithCredentials_Pass_WithValidConfig(t *testing.T) { | ||||
| 	Config = &config.Configuration{ | ||||
| 		Credentials: []config.Credential{ | ||||
| 			config.Credential{ | ||||
| 				ClientID:     "client-id-for-test", | ||||
| 				ClientSecret: "client-secret-for-test", | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	clientId, pass := CheckRequestWithCredentials("1684493546276", "nwBJQmaBLv9+5/sSS/66jcFc1/kGY5wo38L88LOGfRU=") | ||||
| 	if !pass { | ||||
| 		t.Errorf("pass should be true, but false") | ||||
| 		return | ||||
| 	} | ||||
| 	if clientId != "client-id-for-test" { | ||||
| 		t.Errorf("client id should be \"%s\", but \"%s\"", "client-id-for-test", clientId) | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestCheckRequestWithCredentials_Failed_WithInvalidConfig(t *testing.T) { | ||||
| 	Config = &config.Configuration{ | ||||
| 		Credentials: []config.Credential{ | ||||
| 			config.Credential{ | ||||
| 				ClientID:     "client-id-for-test", | ||||
| 				ClientSecret: "invalid-client-secret-for-test", | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
| 	clientId, pass := CheckRequestWithCredentials("1684493546276", "nwBJQmaBLv9+5/sSS/66jcFc1/kGY5wo38L88LOGfRU=") | ||||
| 	if pass { | ||||
| 		t.Errorf("pass should be false, but true") | ||||
| 		return | ||||
| 	} | ||||
| 	if clientId != "" { | ||||
| 		t.Errorf("client id should be empty") | ||||
| 		return | ||||
| 	} | ||||
| } | ||||
		Reference in New Issue
	
	Block a user
	 金喜@DingTalk
					金喜@DingTalk