From 1686748869d1ad622f58a817aee449d38aa17372 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=8C=E4=B8=AB=E8=AE=B2=E6=A2=B5?= Date: Fri, 17 Mar 2023 10:13:13 +0800 Subject: [PATCH] feat: add user request limit config (#98) --- README.md | 5 +- config.dev.json | 3 +- config/config.go | 7 ++ docker-compose.yml | 1 + go.mod | 2 +- main.go | 178 +-------------------------------- pkg/cache/user_base.go | 35 +++++++ pkg/cache/user_context.go | 22 ++++ pkg/cache/user_mode.go | 22 ++++ pkg/cache/user_requese.go | 19 ++++ pkg/chatgpt/export.go | 56 +++++++++++ pkg/process/process_request.go | 160 +++++++++++++++++++++++++++++ public/public.go | 6 +- service/user.go | 68 ------------- 14 files changed, 333 insertions(+), 251 deletions(-) create mode 100644 pkg/cache/user_base.go create mode 100644 pkg/cache/user_context.go create mode 100644 pkg/cache/user_mode.go create mode 100644 pkg/cache/user_requese.go create mode 100644 pkg/chatgpt/export.go create mode 100644 pkg/process/process_request.go delete mode 100644 service/user.go diff --git a/README.md b/README.md index 8e705f0..3b82630 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ ```sh # 运行项目 -$ docker run -itd --name chatgpt -p 8090:8090 --add-host="host.docker.internal:host-gateway" -e APIKEY=换成你的key -e BASE_URL="" -e MODEL="gpt-3.5-turbo" -e SESSION_TIMEOUT=600 -e HTTP_PROXY="http://host.docker.internal:15732" -e DEFAULT_MODE="单聊" --restart=always dockerproxy.com/eryajf/chatgpt-dingtalk:latest +$ docker run -itd --name chatgpt -p 8090:8090 --add-host="host.docker.internal:host-gateway" -e APIKEY=换成你的key -e BASE_URL="" -e MODEL="gpt-3.5-turbo" -e SESSION_TIMEOUT=600 -e HTTP_PROXY="http://host.docker.internal:15732" -e DEFAULT_MODE="单聊" -e MAX_REQUEST=0 --restart=always dockerproxy.com/eryajf/chatgpt-dingtalk:latest ``` `📢 注意:`如果使用docker部署,那么proxy地址可以直接使用如上方式部署,`host.docker.internal`会指向容器所在宿主机的IP,只需要更改端口为你的代理端口即可。参见:[Docker容器如何优雅地访问宿主机网络](https://wiki.eryajf.net/pages/674f53/) @@ -243,7 +243,8 @@ $ go run main.go "model": "gpt-3.5-turbo", // 指定模型,默认为 gpt-3.5-turbo , 可选参数有: "gpt-4-32k-0314", "gpt-4-32k", "gpt-4-0314", "gpt-4", "gpt-3.5-turbo-0301", "gpt-3.5-turbo", "text-davinci-003", "text-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-001", "davinci-instruct-beta", "davinci", "curie-instruct-beta", "curie", "ada", "babbage" "session_timeout": 600, // 会话超时时间,默认600秒,在会话时间内所有发送给机器人的信息会作为上下文 "http_proxy": "", // 指定请求时使用的代理,如果为空,则不使用代理 - "default_mode": "单聊" // 默认对话模式,可根据实际场景自定义,如果不设置,默认为单聊 + "default_mode": "单聊", // 默认对话模式,可根据实际场景自定义,如果不设置,默认为单聊 + "max_request": 0 // 单人单日请求次数限制,默认为0,即不限制 } ``` diff --git a/config.dev.json b/config.dev.json index 05beb00..f6fcadc 100644 --- a/config.dev.json +++ b/config.dev.json @@ -4,5 +4,6 @@ "model": "gpt-3.5-turbo", "session_timeout": 600, "http_proxy": "", - "default_mode": "单聊" + "default_mode": "单聊", + "max_request": 0 } \ No newline at end of file diff --git a/config/config.go b/config/config.go index a2a27b5..93af72d 100644 --- a/config/config.go +++ b/config/config.go @@ -25,6 +25,8 @@ type Configuration struct { DefaultMode string `json:"default_mode"` // 代理地址 HttpProxy string `json:"http_proxy"` + // 用户单日最大请求次数 + MaxRequest int `json:"max_request"` } var config *Configuration @@ -54,6 +56,7 @@ func LoadConfig() *Configuration { sessionTimeout := os.Getenv("SESSION_TIMEOUT") defaultMode := os.Getenv("DEFAULT_MODE") httpProxy := os.Getenv("HTTP_PROXY") + maxRequest := os.Getenv("MAX_REQUEST") if apiKey != "" { config.ApiKey = apiKey } @@ -79,6 +82,10 @@ func LoadConfig() *Configuration { if model != "" { config.Model = model } + if maxRequest != "" { + newMR, _ := strconv.Atoi(maxRequest) + config.MaxRequest = newMR + } }) if config.Model == "" { config.DefaultMode = "gpt-3.5-turbo" diff --git a/docker-compose.yml b/docker-compose.yml index f7c4389..9df6545 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -12,6 +12,7 @@ services: SESSION_TIMEOUT: 600 # 超时时间 HTTP_PROXY: http://host.docker.internal:15777 # 配置代理,注意:host.docker.internal会解析到容器所在的宿主机IP,如果你的服务部署在宿主机,只需要更改端口即可 DEFAULT_MODE: "单聊" # 聊天模式 + MAX_REQUEST: 0 # 单人单日请求次数限制,默认为0,即不限制 ports: - "8090:8090" extra_hosts: diff --git a/go.mod b/go.mod index 9bdd11b..2d7f412 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,13 @@ module github.com/eryajf/chatgpt-dingtalk go 1.18 require ( - github.com/avast/retry-go v2.7.0+incompatible github.com/go-resty/resty/v2 v2.7.0 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/solywsh/chatgpt v0.0.14 ) require ( + github.com/avast/retry-go v2.7.0+incompatible // indirect github.com/sashabaranov/go-openai v1.5.1 // indirect github.com/stretchr/testify v1.8.2 // indirect golang.org/x/net v0.0.0-20211029224645-99673261e6eb // indirect diff --git a/main.go b/main.go index 7955685..8e22b27 100644 --- a/main.go +++ b/main.go @@ -6,12 +6,10 @@ import ( "io/ioutil" "net/http" "strings" - "time" - "github.com/avast/retry-go" + "github.com/eryajf/chatgpt-dingtalk/pkg/process" "github.com/eryajf/chatgpt-dingtalk/public" "github.com/eryajf/chatgpt-dingtalk/public/logger" - "github.com/solywsh/chatgpt" ) func init() { @@ -66,7 +64,7 @@ func Start() { } } else { logger.Info(fmt.Sprintf("dingtalk callback parameters: %#v", msgObj)) - err = ProcessRequest(*msgObj) + err = process.ProcessRequest(*msgObj) if err != nil { logger.Warning(fmt.Errorf("process request failed: %v", err)) } @@ -86,175 +84,3 @@ func Start() { logger.Danger(err) } } - -func ProcessRequest(rmsg public.ReceiveMsg) error { - content := strings.TrimSpace(rmsg.Text.Content) - switch content { - case "单聊": - public.UserService.SetUserMode(rmsg.SenderStaffId, content) - _, err := rmsg.ReplyText(fmt.Sprintf("=====现在进入与👉%s👈单聊的模式 =====", rmsg.SenderNick), rmsg.SenderStaffId) - if err != nil { - logger.Warning(fmt.Errorf("send message error: %v", err)) - } - case "串聊": - public.UserService.SetUserMode(rmsg.SenderStaffId, content) - _, err := rmsg.ReplyText(fmt.Sprintf("=====现在进入与👉%s👈串聊的模式 =====", rmsg.SenderNick), rmsg.SenderStaffId) - if err != nil { - logger.Warning(fmt.Errorf("send message error: %v", err)) - } - case "重置": - public.UserService.ClearUserMode(rmsg.SenderStaffId) - public.UserService.ClearUserSessionContext(rmsg.SenderStaffId) - _, err := rmsg.ReplyText(fmt.Sprintf("=====已重置与👉%s👈的对话模式,可以开始新的对话=====", rmsg.SenderNick), rmsg.SenderStaffId) - if err != nil { - logger.Warning(fmt.Errorf("send message error: %v", err)) - } - case "余额": - cacheMsg := public.UserService.GetUserMode("system_balance") - if cacheMsg == "" { - rst, err := public.GetBalance() - if err != nil { - logger.Warning(fmt.Errorf("get balance error: %v", err)) - return err - } - t1 := time.Unix(int64(rst.Grants.Data[0].EffectiveAt), 0) - t2 := time.Unix(int64(rst.Grants.Data[0].ExpiresAt), 0) - cacheMsg = fmt.Sprintf("💵 已用: 💲%v\n💵 剩余: 💲%v\n⏳ 有效时间: 从 %v 到 %v\n", fmt.Sprintf("%.2f", rst.TotalUsed), fmt.Sprintf("%.2f", rst.TotalAvailable), t1.Format("2006-01-02 15:04:05"), t2.Format("2006-01-02 15:04:05")) - } - - _, err := rmsg.ReplyText(cacheMsg, rmsg.SenderStaffId) - if err != nil { - logger.Warning(fmt.Errorf("send message error: %v", err)) - } - default: - if public.FirstCheck(rmsg) { - return Do("串聊", rmsg) - } else { - return Do("单聊", rmsg) - } - } - return nil -} - -func Do(mode string, rmsg public.ReceiveMsg) error { - // 先把模式注入 - public.UserService.SetUserMode(rmsg.SenderStaffId, mode) - switch mode { - case "单聊": - reply, err := SingleQa(rmsg.Text.Content, rmsg.SenderStaffId) - if err != nil { - logger.Info(fmt.Errorf("gpt request error: %v", err)) - if strings.Contains(fmt.Sprintf("%v", err), "maximum text length exceeded") { - public.UserService.ClearUserSessionContext(rmsg.SenderStaffId) - _, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v,看起来是超过最大对话限制了,已自动重置您的对话", err), rmsg.SenderStaffId) - if err != nil { - logger.Warning(fmt.Errorf("send message error: %v", err)) - return err - } - } else { - _, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v", err), rmsg.SenderStaffId) - if err != nil { - logger.Warning(fmt.Errorf("send message error: %v", err)) - return err - } - } - } - if reply == "" { - logger.Warning(fmt.Errorf("get gpt result falied: %v", err)) - return nil - } else { - reply = strings.TrimSpace(reply) - reply = strings.Trim(reply, "\n") - // 回复@我的用户 - // fmt.Println("单聊结果是:", reply) - _, err = rmsg.ReplyText(reply, rmsg.SenderStaffId) - if err != nil { - logger.Warning(fmt.Errorf("send message error: %v", err)) - return err - } - } - case "串聊": - cli, reply, err := ContextQa(rmsg.Text.Content, rmsg.SenderStaffId) - if err != nil { - logger.Info(fmt.Sprintf("gpt request error: %v", err)) - if strings.Contains(fmt.Sprintf("%v", err), "maximum text length exceeded") { - public.UserService.ClearUserSessionContext(rmsg.SenderStaffId) - _, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v,看起来是超过最大对话限制了,已自动重置您的对话", err), rmsg.SenderStaffId) - if err != nil { - logger.Warning(fmt.Errorf("send message error: %v", err)) - return err - } - } else { - _, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v", err), rmsg.SenderStaffId) - if err != nil { - logger.Warning(fmt.Errorf("send message error: %v", err)) - return err - } - } - } - if reply == "" { - logger.Warning(fmt.Errorf("get gpt result falied: %v", err)) - return nil - } else { - reply = strings.TrimSpace(reply) - reply = strings.Trim(reply, "\n") - // 回复@我的用户 - _, err = rmsg.ReplyText(reply, rmsg.SenderStaffId) - if err != nil { - logger.Warning(fmt.Errorf("send message error: %v", err)) - return err - } - _ = cli.ChatContext.SaveConversation(rmsg.SenderStaffId) - } - default: - - } - return nil -} - -func SingleQa(question, userId string) (answer string, err error) { - chat := chatgpt.New(userId) - defer chat.Close() - // 定义一个重试策略 - retryStrategy := []retry.Option{ - retry.Delay(100 * time.Millisecond), - retry.Attempts(3), - retry.LastErrorOnly(true), - } - // 使用重试策略进行重试 - err = retry.Do( - func() error { - answer, err = chat.ChatWithContext(question) - if err != nil { - return err - } - return nil - }, - retryStrategy...) - return -} - -func ContextQa(question, userId string) (chat *chatgpt.ChatGPT, answer string, err error) { - chat = chatgpt.New(userId) - if public.UserService.GetUserSessionContext(userId) != "" { - err := chat.ChatContext.LoadConversation(userId) - if err != nil { - logger.Warning("load station failed: %v\n", err) - } - } - retryStrategy := []retry.Option{ - retry.Delay(100 * time.Millisecond), - retry.Attempts(3), - retry.LastErrorOnly(true)} - // 使用重试策略进行重试 - err = retry.Do( - func() error { - answer, err = chat.ChatWithContext(question) - if err != nil { - return err - } - return nil - }, - retryStrategy...) - return -} diff --git a/pkg/cache/user_base.go b/pkg/cache/user_base.go new file mode 100644 index 0000000..ae7cf0c --- /dev/null +++ b/pkg/cache/user_base.go @@ -0,0 +1,35 @@ +package cache + +import ( + "time" + + "github.com/patrickmn/go-cache" +) + +// UserServiceInterface 用户业务接口 +type UserServiceInterface interface { + // 用户聊天模式 + GetUserMode(userId string) string + SetUserMode(userId, mode string) + ClearUserMode(userId string) + // 用户聊天上下文 + GetUserSessionContext(userId string) string + SetUserSessionContext(userId, content string) + ClearUserSessionContext(userId string) + // 用户请求次数 + SetUseRequestCount(userId string, current int) + GetUseRequestCount(uerId string) int +} + +var _ UserServiceInterface = (*UserService)(nil) + +// UserService 用戶业务 +type UserService struct { + // 缓存 + cache *cache.Cache +} + +// NewUserService 创建新的业务层 +func NewUserService() UserServiceInterface { + return &UserService{cache: cache.New(time.Hour*2, time.Hour*5)} +} diff --git a/pkg/cache/user_context.go b/pkg/cache/user_context.go new file mode 100644 index 0000000..d76e08a --- /dev/null +++ b/pkg/cache/user_context.go @@ -0,0 +1,22 @@ +package cache + +import "github.com/patrickmn/go-cache" + +// SetUserSessionContext 设置用户会话上下文文本,question用户提问内容,GTP回复内容 +func (s *UserService) SetUserSessionContext(userId string, content string) { + s.cache.Set(userId+"_content", content, cache.DefaultExpiration) +} + +// GetUserSessionContext 获取用户会话上下文文本 +func (s *UserService) GetUserSessionContext(userId string) string { + sessionContext, ok := s.cache.Get(userId + "_content") + if !ok { + return "" + } + return sessionContext.(string) +} + +// ClearUserSessionContext 清空GTP上下文,接收文本中包含 SessionClearToken +func (s *UserService) ClearUserSessionContext(userId string) { + s.cache.Delete(userId + "_content") +} diff --git a/pkg/cache/user_mode.go b/pkg/cache/user_mode.go new file mode 100644 index 0000000..dca62b2 --- /dev/null +++ b/pkg/cache/user_mode.go @@ -0,0 +1,22 @@ +package cache + +import "github.com/patrickmn/go-cache" + +// GetUserMode 获取当前对话模式 +func (s *UserService) GetUserMode(userId string) string { + sessionContext, ok := s.cache.Get(userId + "_mode") + if !ok { + return "" + } + return sessionContext.(string) +} + +// SetUserMode 设置用户对话模式 +func (s *UserService) SetUserMode(userId string, mode string) { + s.cache.Set(userId+"_mode", mode, cache.DefaultExpiration) +} + +// ClearUserMode 重置用户对话模式 +func (s *UserService) ClearUserMode(userId string) { + s.cache.Delete(userId + "_mode") +} diff --git a/pkg/cache/user_requese.go b/pkg/cache/user_requese.go new file mode 100644 index 0000000..c83ca5a --- /dev/null +++ b/pkg/cache/user_requese.go @@ -0,0 +1,19 @@ +package cache + +import ( + "time" +) + +// SetUseRequestCount 设置用户请求次数 +func (s *UserService) SetUseRequestCount(userId string, current int) { + s.cache.Set(userId+"_request", current, time.Hour*24) +} + +// GetUseRequestCount 获取当前用户已请求次数 +func (s *UserService) GetUseRequestCount(userId string) int { + sessionContext, ok := s.cache.Get(userId + "_request") + if !ok { + return 0 + } + return sessionContext.(int) +} diff --git a/pkg/chatgpt/export.go b/pkg/chatgpt/export.go new file mode 100644 index 0000000..8ff4745 --- /dev/null +++ b/pkg/chatgpt/export.go @@ -0,0 +1,56 @@ +package chatgpt + +import ( + "time" + + "github.com/avast/retry-go" + "github.com/eryajf/chatgpt-dingtalk/public" + "github.com/eryajf/chatgpt-dingtalk/public/logger" +) + +func SingleQa(question, userId string) (answer string, err error) { + chat := New(userId) + defer chat.Close() + // 定义一个重试策略 + retryStrategy := []retry.Option{ + retry.Delay(100 * time.Millisecond), + retry.Attempts(3), + retry.LastErrorOnly(true), + } + // 使用重试策略进行重试 + err = retry.Do( + func() error { + answer, err = chat.ChatWithContext(question) + if err != nil { + return err + } + return nil + }, + retryStrategy...) + return +} + +func ContextQa(question, userId string) (chat *ChatGPT, answer string, err error) { + chat = New(userId) + if public.UserService.GetUserSessionContext(userId) != "" { + err := chat.ChatContext.LoadConversation(userId) + if err != nil { + logger.Warning("load station failed: %v\n", err) + } + } + retryStrategy := []retry.Option{ + retry.Delay(100 * time.Millisecond), + retry.Attempts(3), + retry.LastErrorOnly(true)} + // 使用重试策略进行重试 + err = retry.Do( + func() error { + answer, err = chat.ChatWithContext(question) + if err != nil { + return err + } + return nil + }, + retryStrategy...) + return +} diff --git a/pkg/process/process_request.go b/pkg/process/process_request.go new file mode 100644 index 0000000..132baac --- /dev/null +++ b/pkg/process/process_request.go @@ -0,0 +1,160 @@ +package process + +import ( + "fmt" + "strings" + "time" + + "github.com/eryajf/chatgpt-dingtalk/public" + "github.com/eryajf/chatgpt-dingtalk/public/logger" + "github.com/solywsh/chatgpt" +) + +// ProcessRequest 分析处理请求逻辑 +func ProcessRequest(rmsg public.ReceiveMsg) error { + if CheckRequest(rmsg) { + content := strings.TrimSpace(rmsg.Text.Content) + switch content { + case "单聊": + public.UserService.SetUserMode(rmsg.SenderStaffId, content) + _, err := rmsg.ReplyText(fmt.Sprintf("=====现在进入与👉%s👈单聊的模式 =====", rmsg.SenderNick), rmsg.SenderStaffId) + if err != nil { + logger.Warning(fmt.Errorf("send message error: %v", err)) + } + case "串聊": + public.UserService.SetUserMode(rmsg.SenderStaffId, content) + _, err := rmsg.ReplyText(fmt.Sprintf("=====现在进入与👉%s👈串聊的模式 =====", rmsg.SenderNick), rmsg.SenderStaffId) + if err != nil { + logger.Warning(fmt.Errorf("send message error: %v", err)) + } + case "重置": + public.UserService.ClearUserMode(rmsg.SenderStaffId) + public.UserService.ClearUserSessionContext(rmsg.SenderStaffId) + _, err := rmsg.ReplyText(fmt.Sprintf("=====已重置与👉%s👈的对话模式,可以开始新的对话=====", rmsg.SenderNick), rmsg.SenderStaffId) + if err != nil { + logger.Warning(fmt.Errorf("send message error: %v", err)) + } + case "余额": + cacheMsg := public.UserService.GetUserMode("system_balance") + if cacheMsg == "" { + rst, err := public.GetBalance() + if err != nil { + logger.Warning(fmt.Errorf("get balance error: %v", err)) + return err + } + t1 := time.Unix(int64(rst.Grants.Data[0].EffectiveAt), 0) + t2 := time.Unix(int64(rst.Grants.Data[0].ExpiresAt), 0) + cacheMsg = fmt.Sprintf("💵 已用: 💲%v\n💵 剩余: 💲%v\n⏳ 有效时间: 从 %v 到 %v\n", fmt.Sprintf("%.2f", rst.TotalUsed), fmt.Sprintf("%.2f", rst.TotalAvailable), t1.Format("2006-01-02 15:04:05"), t2.Format("2006-01-02 15:04:05")) + } + + _, err := rmsg.ReplyText(cacheMsg, rmsg.SenderStaffId) + if err != nil { + logger.Warning(fmt.Errorf("send message error: %v", err)) + } + default: + if public.FirstCheck(rmsg) { + return Do("串聊", rmsg) + } else { + return Do("单聊", rmsg) + } + } + } + return nil +} + +// 执行处理请求 +func Do(mode string, rmsg public.ReceiveMsg) error { + // 先把模式注入 + public.UserService.SetUserMode(rmsg.SenderStaffId, mode) + switch mode { + case "单聊": + reply, err := chatgpt.SingleQa(rmsg.Text.Content, rmsg.SenderStaffId) + if err != nil { + logger.Info(fmt.Errorf("gpt request error: %v", err)) + if strings.Contains(fmt.Sprintf("%v", err), "maximum text length exceeded") { + public.UserService.ClearUserSessionContext(rmsg.SenderStaffId) + _, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v,看起来是超过最大对话限制了,已自动重置您的对话", err), rmsg.SenderStaffId) + if err != nil { + logger.Warning(fmt.Errorf("send message error: %v", err)) + return err + } + } else { + _, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v", err), rmsg.SenderStaffId) + if err != nil { + logger.Warning(fmt.Errorf("send message error: %v", err)) + return err + } + } + } + if reply == "" { + logger.Warning(fmt.Errorf("get gpt result falied: %v", err)) + return nil + } else { + reply = strings.TrimSpace(reply) + reply = strings.Trim(reply, "\n") + // 回复@我的用户 + // fmt.Println("单聊结果是:", reply) + _, err = rmsg.ReplyText(reply, rmsg.SenderStaffId) + if err != nil { + logger.Warning(fmt.Errorf("send message error: %v", err)) + return err + } + } + case "串聊": + cli, reply, err := chatgpt.ContextQa(rmsg.Text.Content, rmsg.SenderStaffId) + if err != nil { + logger.Info(fmt.Sprintf("gpt request error: %v", err)) + if strings.Contains(fmt.Sprintf("%v", err), "maximum text length exceeded") { + public.UserService.ClearUserSessionContext(rmsg.SenderStaffId) + _, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v,看起来是超过最大对话限制了,已自动重置您的对话", err), rmsg.SenderStaffId) + if err != nil { + logger.Warning(fmt.Errorf("send message error: %v", err)) + return err + } + } else { + _, err = rmsg.ReplyText(fmt.Sprintf("请求openai失败了,错误信息:%v", err), rmsg.SenderStaffId) + if err != nil { + logger.Warning(fmt.Errorf("send message error: %v", err)) + return err + } + } + } + if reply == "" { + logger.Warning(fmt.Errorf("get gpt result falied: %v", err)) + return nil + } else { + reply = strings.TrimSpace(reply) + reply = strings.Trim(reply, "\n") + // 回复@我的用户 + _, err = rmsg.ReplyText(reply, rmsg.SenderStaffId) + if err != nil { + logger.Warning(fmt.Errorf("send message error: %v", err)) + return err + } + _ = cli.ChatContext.SaveConversation(rmsg.SenderStaffId) + } + default: + + } + return nil +} + +// ProcessRequest 分析处理请求逻辑 +func CheckRequest(rmsg public.ReceiveMsg) bool { + if public.Config.MaxRequest == 0 { + return true + } + count := public.UserService.GetUseRequestCount(rmsg.SenderStaffId) + // 判断访问次数是否超过限制 + if count >= public.Config.MaxRequest { + logger.Info(fmt.Sprintf("亲爱的: %s,您今日请求次数已达上限,请明天再来,交互发问资源有限,请务必斟酌您的问题,给您带来不便,敬请谅解!", rmsg.SenderNick)) + _, err := rmsg.ReplyText(fmt.Sprintf("一个好的问题,胜过十个好的答案!\n亲爱的: %s,您今日请求次数已达上限,请明天再来,交互发问资源有限,请务必斟酌您的问题,给您带来不便,敬请谅解!", rmsg.SenderNick), rmsg.SenderStaffId) + if err != nil { + logger.Warning(fmt.Errorf("send message error: %v", err)) + } + return false + } + // 访问次数未超过限制,将计数加1 + public.UserService.SetUseRequestCount(rmsg.SenderStaffId, count+1) + return true +} diff --git a/public/public.go b/public/public.go index 1cab1b3..ca66ca0 100644 --- a/public/public.go +++ b/public/public.go @@ -4,15 +4,15 @@ import ( "strings" "github.com/eryajf/chatgpt-dingtalk/config" - "github.com/eryajf/chatgpt-dingtalk/service" + "github.com/eryajf/chatgpt-dingtalk/pkg/cache" ) -var UserService service.UserServiceInterface +var UserService cache.UserServiceInterface var Config *config.Configuration func InitSvc() { Config = config.LoadConfig() - UserService = service.NewUserService() + UserService = cache.NewUserService() _, _ = GetBalance() } diff --git a/service/user.go b/service/user.go deleted file mode 100644 index 406901b..0000000 --- a/service/user.go +++ /dev/null @@ -1,68 +0,0 @@ -package service - -import ( - "time" - - "github.com/patrickmn/go-cache" -) - -// UserServiceInterface 用户业务接口 -type UserServiceInterface interface { - GetUserMode(userId string) string - SetUserMode(userId, mode string) - ClearUserMode(userId string) - GetUserSessionContext(userId string) string - SetUserSessionContext(userId, content string) - ClearUserSessionContext(userId string) -} - -var _ UserServiceInterface = (*UserService)(nil) - -// UserService 用戶业务 -type UserService struct { - // 缓存 - cache *cache.Cache -} - -// NewUserService 创建新的业务层 -func NewUserService() UserServiceInterface { - return &UserService{cache: cache.New(time.Hour*2, time.Hour*5)} -} - -// GetUserMode 获取当前对话模式 -func (s *UserService) GetUserMode(userId string) string { - sessionContext, ok := s.cache.Get(userId + "_mode") - if !ok { - return "" - } - return sessionContext.(string) -} - -// SetUserMode 设置用户对话模式 -func (s *UserService) SetUserMode(userId string, mode string) { - s.cache.Set(userId+"_mode", mode, cache.DefaultExpiration) -} - -// ClearUserMode 重置用户对话模式 -func (s *UserService) ClearUserMode(userId string) { - s.cache.Delete(userId + "_mode") -} - -// SetUserSessionContext 设置用户会话上下文文本,question用户提问内容,GTP回复内容 -func (s *UserService) SetUserSessionContext(userId string, content string) { - s.cache.Set(userId+"_content", content, cache.DefaultExpiration) -} - -// GetUserSessionContext 获取用户会话上下文文本 -func (s *UserService) GetUserSessionContext(userId string) string { - sessionContext, ok := s.cache.Get(userId + "_content") - if !ok { - return "" - } - return sessionContext.(string) -} - -// ClearUserSessionContext 清空GTP上下文,接收文本中包含 SessionClearToken -func (s *UserService) ClearUserSessionContext(userId string) { - s.cache.Delete(userId + "_content") -}