From acb5462a01cd3d2dc0234a9a86d0ec6ac7f390f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=8C=E4=B8=AB=E8=AE=B2=E6=A2=B5?= Date: Tue, 7 Mar 2023 13:29:05 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=8C=87=E5=AE=9A?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E7=9A=84=E8=83=BD=E5=8A=9B=20(#74)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .pre-commit-config.yaml | 12 +------ README.md | 4 ++- config.dev.json | 1 + config/config.go | 6 ++++ main.go | 7 ++-- pkg/chatgpt/context.go | 75 ++++++++++++++++++++++++++++------------- public/gpt.go | 7 ++-- public/public.go | 5 +-- 8 files changed, 71 insertions(+), 46 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7d65668..c3dc9ab 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,14 +3,4 @@ repos: rev: v4.4.0 hooks: - id: check-yaml - - id: check-added-large-files -- repo: https://github.com/golangci/golangci-lint # golangci-lint hook repo - rev: v1.47.3 # golangci-lint hook repo revision - hooks: - - id: golangci-lint - name: golangci-lint - description: Fast linters runner for Go. - entry: golangci-lint run --fix - types: [go] - language: golang - pass_filenames: false \ No newline at end of file + - id: check-added-large-files \ No newline at end of file diff --git a/README.md b/README.md index 04990db..2c9e5f8 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ - 支持在钉钉群聊中添加机器人,通过@机器人进行聊天交互。 - 提问支持单聊与串聊两种模式,通过@机器人发关键字切换。 - 支持添加代理,通过配置化指定。 +- 支持自定义指定的模型,通过配置化指定。 - 支持自定义默认的聊天模式,通过配置化指定。 ## 使用前提 @@ -78,7 +79,7 @@ ```sh # 运行项目 -$ docker run -itd --name chatgpt -p 8090:8090 -e APIKEY=换成你的key -e SESSION_TIMEOUT=600 -e HTTP_PROXY="" -e DEFAULT_MODE="单聊" --restart=always dockerproxy.com/eryajf/chatgpt-dingtalk:latest +$ docker run -itd --name chatgpt -p 8090:8090 -e APIKEY=换成你的key -e MODEL="gpt-3.5-turbo" -e SESSION_TIMEOUT=600 -e HTTP_PROXY="" -e DEFAULT_MODE="单聊" --restart=always dockerproxy.com/eryajf/chatgpt-dingtalk:latest ``` `📢 注意:`如果你使用docker部署,那么proxy指定地址的时候,请指定宿主机的IP,而不要写成127,以免代理不生效。 @@ -221,6 +222,7 @@ $ go run main.go ```json { "api_key": "xxxxxxxxx", // openai api_key + "model": "gpt-3.5-turbo", // 指定模型,默认为 gpt-3.5-turbo ,具体选项参考官网训练场 "session_timeout": 600, // 会话超时时间,默认600秒,在会话时间内所有发送给机器人的信息会作为上下文 "http_proxy": "", // 指定请求时使用的代理,如果为空,则不使用代理 "default_mode": "单聊" // 默认对话模式,可根据实际场景自定义,如果不设置,默认为单聊 diff --git a/config.dev.json b/config.dev.json index b690ec5..7b6b3f8 100644 --- a/config.dev.json +++ b/config.dev.json @@ -1,5 +1,6 @@ { "api_key": "xxxxxxxxx", + "model": "gpt-3.5-turbo", "session_timeout": 600, "http_proxy": "", "default_mode": "单聊" diff --git a/config/config.go b/config/config.go index 64393d7..2f9f8d0 100644 --- a/config/config.go +++ b/config/config.go @@ -15,6 +15,8 @@ import ( type Configuration struct { // gtp apikey ApiKey string `json:"api_key"` + // 使用模型 + Model string `json:"model"` // 会话超时时间 SessionTimeout time.Duration `json:"session_timeout"` // 默认对话模式 @@ -45,6 +47,7 @@ func LoadConfig() *Configuration { } // 如果环境变量有配置,读取环境变量 ApiKey := os.Getenv("APIKEY") + model := os.Getenv("MODEL") SessionTimeout := os.Getenv("SESSION_TIMEOUT") defaultMode := os.Getenv("DEFAULT_MODE") httpProxy := os.Getenv("HTTP_PROXY") @@ -67,6 +70,9 @@ func LoadConfig() *Configuration { if httpProxy != "" { config.HttpProxy = httpProxy } + if model != "" { + config.Model = model + } }) if config.DefaultMode == "" { config.DefaultMode = "单聊" diff --git a/main.go b/main.go index 9dfd6ef..316c0a9 100644 --- a/main.go +++ b/main.go @@ -8,7 +8,6 @@ import ( "strings" "time" - "github.com/eryajf/chatgpt-dingtalk/config" "github.com/eryajf/chatgpt-dingtalk/public" "github.com/eryajf/chatgpt-dingtalk/public/logger" "github.com/solywsh/chatgpt" @@ -213,15 +212,13 @@ func Do(mode string, rmsg public.ReceiveMsg) error { } func SingleQa(question, userId string) (answer string, err error) { - cfg := config.LoadConfig() - chat := chatgpt.New(cfg.ApiKey, cfg.HttpProxy, userId, cfg.SessionTimeout) + chat := chatgpt.New(public.Config.ApiKey, public.Config.HttpProxy, userId, public.Config.SessionTimeout) defer chat.Close() return chat.ChatWithContext(question) } func ContextQa(question, userId string) (chat *chatgpt.ChatGPT, answer string, err error) { - cfg := config.LoadConfig() - chat = chatgpt.New(cfg.ApiKey, cfg.HttpProxy, userId, cfg.SessionTimeout) + chat = chatgpt.New(public.Config.ApiKey, public.Config.HttpProxy, userId, public.Config.SessionTimeout) if public.UserService.GetUserSessionContext(userId) != "" { err = chat.ChatContext.LoadConversation(userId) if err != nil { diff --git a/pkg/chatgpt/context.go b/pkg/chatgpt/context.go index 77e3d32..9416aac 100644 --- a/pkg/chatgpt/context.go +++ b/pkg/chatgpt/context.go @@ -156,31 +156,60 @@ func (c *ChatGPT) ChatWithContext(question string) (answer string, err error) { if len(prompt) > c.maxText-c.maxAnswerLen { return "", OverMaxTextLength } - c1 := gogpt.ChatCompletionRequest{ - Model: gogpt.GPT3Dot5Turbo, - Messages: []gogpt.ChatCompletionMessage{ - { - Role: "user", - Content: prompt, - }, - }} - req := c1 - resp, err := c.client.CreateChatCompletion(c.ctx, req) - if err != nil { - return "", err + if public.Config.Model == gogpt.GPT3Dot5Turbo0301 || public.Config.Model == gogpt.GPT3Dot5Turbo { + req := gogpt.ChatCompletionRequest{ + Model: public.Config.Model, + Messages: []gogpt.ChatCompletionMessage{ + { + Role: "user", + Content: prompt, + }, + }} + resp, err := c.client.CreateChatCompletion(c.ctx, req) + if err != nil { + return "", err + } + resp.Choices[0].Message.Content = formatAnswer(resp.Choices[0].Message.Content) + c.ChatContext.old = append(c.ChatContext.old, conversation{ + Role: c.ChatContext.humanRole, + Prompt: question, + }) + c.ChatContext.old = append(c.ChatContext.old, conversation{ + Role: c.ChatContext.aiRole, + Prompt: resp.Choices[0].Message.Content, + }) + c.ChatContext.seqTimes++ + return resp.Choices[0].Message.Content, nil + } else { + req := gogpt.CompletionRequest{ + Model: public.Config.Model, + MaxTokens: c.maxAnswerLen, + Prompt: prompt, + Temperature: 0.9, + TopP: 1, + N: 1, + FrequencyPenalty: 0, + PresencePenalty: 0.5, + User: c.userId, + Stop: []string{c.ChatContext.aiRole.Name + ":", c.ChatContext.humanRole.Name + ":"}, + } + resp, err := c.client.CreateCompletion(c.ctx, req) + if err != nil { + return "", err + } + resp.Choices[0].Text = formatAnswer(resp.Choices[0].Text) + c.ChatContext.old = append(c.ChatContext.old, conversation{ + Role: c.ChatContext.humanRole, + Prompt: question, + }) + c.ChatContext.old = append(c.ChatContext.old, conversation{ + Role: c.ChatContext.aiRole, + Prompt: resp.Choices[0].Text, + }) + c.ChatContext.seqTimes++ + return resp.Choices[0].Text, nil } - resp.Choices[0].Message.Content = formatAnswer(resp.Choices[0].Message.Content) - c.ChatContext.old = append(c.ChatContext.old, conversation{ - Role: c.ChatContext.humanRole, - Prompt: question, - }) - c.ChatContext.old = append(c.ChatContext.old, conversation{ - Role: c.ChatContext.aiRole, - Prompt: resp.Choices[0].Message.Content, - }) - c.ChatContext.seqTimes++ - return resp.Choices[0].Message.Content, nil } func WithMaxSeqTimes(times int) ChatContextOption { diff --git a/public/gpt.go b/public/gpt.go index 975d80c..81a0852 100644 --- a/public/gpt.go +++ b/public/gpt.go @@ -5,15 +5,14 @@ import ( "fmt" "time" - "github.com/eryajf/chatgpt-dingtalk/config" "github.com/go-resty/resty/v2" ) func InitAiCli() *resty.Client { - if config.LoadConfig().HttpProxy != "" { - return resty.New().SetTimeout(30*time.Second).SetHeader("Authorization", fmt.Sprintf("Bearer %s", config.LoadConfig().ApiKey)).SetProxy(config.LoadConfig().HttpProxy).SetRetryCount(3).SetRetryWaitTime(5 * time.Second) + if Config.HttpProxy != "" { + return resty.New().SetTimeout(30*time.Second).SetHeader("Authorization", fmt.Sprintf("Bearer %s", Config.ApiKey)).SetProxy(Config.HttpProxy).SetRetryCount(3).SetRetryWaitTime(5 * time.Second) } - return resty.New().SetTimeout(30*time.Second).SetHeader("Authorization", fmt.Sprintf("Bearer %s", config.LoadConfig().ApiKey)).SetRetryCount(3).SetRetryWaitTime(5 * time.Second) + return resty.New().SetTimeout(30*time.Second).SetHeader("Authorization", fmt.Sprintf("Bearer %s", Config.ApiKey)).SetRetryCount(3).SetRetryWaitTime(5 * time.Second) } type Billing struct { diff --git a/public/public.go b/public/public.go index e76bad5..1cab1b3 100644 --- a/public/public.go +++ b/public/public.go @@ -8,9 +8,10 @@ import ( ) var UserService service.UserServiceInterface +var Config *config.Configuration func InitSvc() { - config.LoadConfig() + Config = config.LoadConfig() UserService = service.NewUserService() _, _ = GetBalance() } @@ -18,7 +19,7 @@ func InitSvc() { func FirstCheck(rmsg ReceiveMsg) bool { lc := UserService.GetUserMode(rmsg.SenderStaffId) if lc == "" { - if config.LoadConfig().DefaultMode == "串聊" { + if Config.DefaultMode == "串聊" { return true } else { return false